Commit
Β·
346590b
1
Parent(s):
65bf92e
Enhance KMeans analysis visualization using Plotly and update requirements
Browse files- __pycache__/vlai_template.cpython-313.pyc +0 -0
- app.py +147 -105
- requirements.txt +2 -1
__pycache__/vlai_template.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/vlai_template.cpython-313.pyc and b/__pycache__/vlai_template.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import pandas as pd
|
| 3 |
import numpy as np
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
|
| 7 |
import vlai_template
|
| 8 |
|
|
@@ -66,116 +66,143 @@ def standardize_data(X):
|
|
| 66 |
"""Standardize data to have mean=0 and std=1"""
|
| 67 |
return (X - np.mean(X, axis=0)) / np.std(X, axis=0)
|
| 68 |
|
| 69 |
-
def
|
| 70 |
-
"""Create a
|
| 71 |
-
width, height = 1400, 700
|
| 72 |
-
margin = 80
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
return plot_x_offset + margin + (x - x_min) / (x_max - x_min) * plot_width
|
| 84 |
-
|
| 85 |
-
def norm_y(y):
|
| 86 |
-
return margin + (1 - (y - y_min) / (y_max - y_min)) * plot_height
|
| 87 |
-
|
| 88 |
-
svg_content = f'''
|
| 89 |
-
<div class="plot-container" style="width: 100%; overflow-x: auto;">
|
| 90 |
-
<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}"
|
| 91 |
-
style="width: 100%; height: auto; min-width: 800px;" xmlns="http://www.w3.org/2000/svg">
|
| 92 |
-
<defs>
|
| 93 |
-
<style>
|
| 94 |
-
.title {{ font: bold 20px sans-serif; text-anchor: middle; }}
|
| 95 |
-
.axis-label {{ font: 16px sans-serif; text-anchor: middle; }}
|
| 96 |
-
.legend {{ font: 14px sans-serif; }}
|
| 97 |
-
</style>
|
| 98 |
-
</defs>
|
| 99 |
-
|
| 100 |
-
<!-- Plot 1: Original Wine Types -->
|
| 101 |
-
<g>
|
| 102 |
-
<rect x="{margin}" y="{margin}" width="{plot_width}" height="{plot_height}"
|
| 103 |
-
fill="none" stroke="black" stroke-width="1"/>
|
| 104 |
-
<text x="{margin + plot_width//2}" y="{margin - 10}" class="title">
|
| 105 |
-
Original Data (Red vs White Wine)
|
| 106 |
-
</text>
|
| 107 |
-
<text x="{margin + plot_width//2}" y="{height - 10}" class="axis-label">
|
| 108 |
-
PC1 ({explained_var[0]:.1%} variance)
|
| 109 |
-
</text>
|
| 110 |
-
<text x="{margin - 30}" y="{margin + plot_height//2}" class="axis-label"
|
| 111 |
-
transform="rotate(-90, {margin - 30}, {margin + plot_height//2})">
|
| 112 |
-
PC2 ({explained_var[1]:.1%} variance)
|
| 113 |
-
</text>
|
| 114 |
-
'''
|
| 115 |
-
|
| 116 |
-
# Plot original data points
|
| 117 |
-
for i, (point, wine_type) in enumerate(zip(X_pca, wine_types)):
|
| 118 |
-
color = "#d62728" if wine_type == "red" else "#1f77b4" # red or blue
|
| 119 |
-
x, y = norm_x(point[0]), norm_y(point[1])
|
| 120 |
-
svg_content += f'<circle cx="{x}" cy="{y}" r="2" fill="{color}" opacity="0.6"/>\n'
|
| 121 |
|
| 122 |
-
# Plot 1
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
<text x="{margin + 20}" y="{margin + 25}" class="legend">Red Wine</text>
|
| 126 |
-
<circle cx="{margin + 10}" cy="{margin + 40}" r="4" fill="#1f77b4"/>
|
| 127 |
-
<text x="{margin + 20}" y="{margin + 45}" class="legend">White Wine</text>
|
| 128 |
-
</g>
|
| 129 |
-
|
| 130 |
-
<!-- Plot 2: KMeans Clusters -->
|
| 131 |
-
<g>
|
| 132 |
-
'''
|
| 133 |
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
'''
|
| 150 |
|
| 151 |
-
#
|
| 152 |
cluster_colors = ["#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22"]
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
for i
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
#
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
#
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
return
|
| 179 |
|
| 180 |
def run_kmeans_analysis(n_clusters, random_state):
|
| 181 |
"""Main function to run the KMeans analysis"""
|
|
@@ -202,12 +229,27 @@ def run_kmeans_analysis(n_clusters, random_state):
|
|
| 202 |
|
| 203 |
# Create plot
|
| 204 |
wine_types = df['wine_type'].tolist() if 'wine_type' in df.columns else ['unknown'] * len(df)
|
| 205 |
-
|
| 206 |
|
| 207 |
-
return
|
| 208 |
|
| 209 |
except Exception as e:
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
# ββββββββββββββββββββββββββββ Main βββββββββββββββββββββββββ
|
| 213 |
with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, title="Wine Quality KMeans Demo") as demo:
|
|
@@ -245,7 +287,7 @@ with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, title="Wine
|
|
| 245 |
""")
|
| 246 |
|
| 247 |
with gr.Column(scale=7):
|
| 248 |
-
output_plot = gr.
|
| 249 |
|
| 250 |
run_btn.click(
|
| 251 |
run_kmeans_analysis,
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import pandas as pd
|
| 3 |
import numpy as np
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
from plotly.subplots import make_subplots
|
| 6 |
|
| 7 |
import vlai_template
|
| 8 |
|
|
|
|
| 66 |
"""Standardize data to have mean=0 and std=1"""
|
| 67 |
return (X - np.mean(X, axis=0)) / np.std(X, axis=0)
|
| 68 |
|
| 69 |
+
def create_plotly_visualization(X_pca, wine_types, labels, centroids_pca, k, explained_var):
|
| 70 |
+
"""Create a plotly visualization with two subplots"""
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
# Create subplots
|
| 73 |
+
fig = make_subplots(
|
| 74 |
+
rows=1, cols=2,
|
| 75 |
+
subplot_titles=(
|
| 76 |
+
"Original Data (Red vs White Wine)",
|
| 77 |
+
f"After KMeans Clustering (K={k})"
|
| 78 |
+
),
|
| 79 |
+
horizontal_spacing=0.1
|
| 80 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
# Plot 1: Original Wine Types
|
| 83 |
+
red_mask = np.array(wine_types) == "red"
|
| 84 |
+
white_mask = np.array(wine_types) == "white"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
# Add red wine points
|
| 87 |
+
if np.any(red_mask):
|
| 88 |
+
fig.add_trace(
|
| 89 |
+
go.Scatter(
|
| 90 |
+
x=X_pca[red_mask, 0],
|
| 91 |
+
y=X_pca[red_mask, 1],
|
| 92 |
+
mode='markers',
|
| 93 |
+
marker=dict(color='#d62728', size=4, opacity=0.6),
|
| 94 |
+
name='Red Wine',
|
| 95 |
+
showlegend=True
|
| 96 |
+
),
|
| 97 |
+
row=1, col=1
|
| 98 |
+
)
|
| 99 |
|
| 100 |
+
# Add white wine points
|
| 101 |
+
if np.any(white_mask):
|
| 102 |
+
fig.add_trace(
|
| 103 |
+
go.Scatter(
|
| 104 |
+
x=X_pca[white_mask, 0],
|
| 105 |
+
y=X_pca[white_mask, 1],
|
| 106 |
+
mode='markers',
|
| 107 |
+
marker=dict(color='#1f77b4', size=4, opacity=0.6),
|
| 108 |
+
name='White Wine',
|
| 109 |
+
showlegend=True
|
| 110 |
+
),
|
| 111 |
+
row=1, col=1
|
| 112 |
+
)
|
|
|
|
| 113 |
|
| 114 |
+
# Plot 2: KMeans Clusters
|
| 115 |
cluster_colors = ["#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22"]
|
| 116 |
|
| 117 |
+
# Add cluster points
|
| 118 |
+
for i in range(k):
|
| 119 |
+
cluster_mask = labels == i
|
| 120 |
+
if np.any(cluster_mask):
|
| 121 |
+
fig.add_trace(
|
| 122 |
+
go.Scatter(
|
| 123 |
+
x=X_pca[cluster_mask, 0],
|
| 124 |
+
y=X_pca[cluster_mask, 1],
|
| 125 |
+
mode='markers',
|
| 126 |
+
marker=dict(color=cluster_colors[i % len(cluster_colors)], size=4, opacity=0.6),
|
| 127 |
+
name=f'Cluster {i}',
|
| 128 |
+
showlegend=True
|
| 129 |
+
),
|
| 130 |
+
row=1, col=2
|
| 131 |
+
)
|
| 132 |
|
| 133 |
+
# Add centroids
|
| 134 |
+
fig.add_trace(
|
| 135 |
+
go.Scatter(
|
| 136 |
+
x=centroids_pca[:, 0],
|
| 137 |
+
y=centroids_pca[:, 1],
|
| 138 |
+
mode='markers+text',
|
| 139 |
+
marker=dict(color='black', size=12, line=dict(color='white', width=2)),
|
| 140 |
+
text=[str(i) for i in range(k)],
|
| 141 |
+
textfont=dict(color='white', size=10),
|
| 142 |
+
textposition="middle center",
|
| 143 |
+
name='Centroids',
|
| 144 |
+
showlegend=True
|
| 145 |
+
),
|
| 146 |
+
row=1, col=2
|
| 147 |
+
)
|
| 148 |
|
| 149 |
+
# Update layout
|
| 150 |
+
fig.update_layout(
|
| 151 |
+
title="Wine Quality Dataset - KMeans Clustering with PCA",
|
| 152 |
+
title_x=0.5,
|
| 153 |
+
height=600,
|
| 154 |
+
plot_bgcolor='white',
|
| 155 |
+
paper_bgcolor='white',
|
| 156 |
+
font=dict(size=12),
|
| 157 |
+
showlegend=True,
|
| 158 |
+
legend=dict(
|
| 159 |
+
orientation="h",
|
| 160 |
+
yanchor="bottom",
|
| 161 |
+
y=-0.2,
|
| 162 |
+
xanchor="center",
|
| 163 |
+
x=0.5
|
| 164 |
+
)
|
| 165 |
+
)
|
| 166 |
|
| 167 |
+
# Update x and y axes labels and styling
|
| 168 |
+
fig.update_xaxes(
|
| 169 |
+
title_text=f"PC1 ({explained_var[0]:.1%} variance)",
|
| 170 |
+
showgrid=True,
|
| 171 |
+
gridcolor='lightgray',
|
| 172 |
+
gridwidth=1,
|
| 173 |
+
zeroline=True,
|
| 174 |
+
zerolinecolor='lightgray',
|
| 175 |
+
row=1, col=1
|
| 176 |
+
)
|
| 177 |
+
fig.update_xaxes(
|
| 178 |
+
title_text=f"PC1 ({explained_var[0]:.1%} variance)",
|
| 179 |
+
showgrid=True,
|
| 180 |
+
gridcolor='lightgray',
|
| 181 |
+
gridwidth=1,
|
| 182 |
+
zeroline=True,
|
| 183 |
+
zerolinecolor='lightgray',
|
| 184 |
+
row=1, col=2
|
| 185 |
+
)
|
| 186 |
+
fig.update_yaxes(
|
| 187 |
+
title_text=f"PC2 ({explained_var[1]:.1%} variance)",
|
| 188 |
+
showgrid=True,
|
| 189 |
+
gridcolor='lightgray',
|
| 190 |
+
gridwidth=1,
|
| 191 |
+
zeroline=True,
|
| 192 |
+
zerolinecolor='lightgray',
|
| 193 |
+
row=1, col=1
|
| 194 |
+
)
|
| 195 |
+
fig.update_yaxes(
|
| 196 |
+
title_text=f"PC2 ({explained_var[1]:.1%} variance)",
|
| 197 |
+
showgrid=True,
|
| 198 |
+
gridcolor='lightgray',
|
| 199 |
+
gridwidth=1,
|
| 200 |
+
zeroline=True,
|
| 201 |
+
zerolinecolor='lightgray',
|
| 202 |
+
row=1, col=2
|
| 203 |
+
)
|
| 204 |
|
| 205 |
+
return fig
|
| 206 |
|
| 207 |
def run_kmeans_analysis(n_clusters, random_state):
|
| 208 |
"""Main function to run the KMeans analysis"""
|
|
|
|
| 229 |
|
| 230 |
# Create plot
|
| 231 |
wine_types = df['wine_type'].tolist() if 'wine_type' in df.columns else ['unknown'] * len(df)
|
| 232 |
+
plot_fig = create_plotly_visualization(X_pca, wine_types, labels, centroids_pca, n_clusters, explained_var)
|
| 233 |
|
| 234 |
+
return plot_fig
|
| 235 |
|
| 236 |
except Exception as e:
|
| 237 |
+
# Return an empty plotly figure with error message
|
| 238 |
+
fig = go.Figure()
|
| 239 |
+
fig.add_annotation(
|
| 240 |
+
text=f"Error: {str(e)}",
|
| 241 |
+
xref="paper", yref="paper",
|
| 242 |
+
x=0.5, y=0.5,
|
| 243 |
+
showarrow=False,
|
| 244 |
+
font=dict(color="red", size=16)
|
| 245 |
+
)
|
| 246 |
+
fig.update_layout(
|
| 247 |
+
plot_bgcolor='white',
|
| 248 |
+
paper_bgcolor='white',
|
| 249 |
+
xaxis=dict(visible=False),
|
| 250 |
+
yaxis=dict(visible=False)
|
| 251 |
+
)
|
| 252 |
+
return fig
|
| 253 |
|
| 254 |
# ββββββββββββββββββββββββββββ Main βββββββββββββββββββββββββ
|
| 255 |
with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, title="Wine Quality KMeans Demo") as demo:
|
|
|
|
| 287 |
""")
|
| 288 |
|
| 289 |
with gr.Column(scale=7):
|
| 290 |
+
output_plot = gr.Plot(label="π PCA Visualization & KMeans Results")
|
| 291 |
|
| 292 |
run_btn.click(
|
| 293 |
run_kmeans_analysis,
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
gradio==5.38.0
|
| 2 |
pandas
|
| 3 |
-
numpy
|
|
|
|
|
|
| 1 |
gradio==5.38.0
|
| 2 |
pandas
|
| 3 |
+
numpy
|
| 4 |
+
plotly
|