xxnithicxx commited on
Commit
346590b
Β·
1 Parent(s): 65bf92e

Enhance KMeans analysis visualization using Plotly and update requirements

Browse files
__pycache__/vlai_template.cpython-313.pyc CHANGED
Binary files a/__pycache__/vlai_template.cpython-313.pyc and b/__pycache__/vlai_template.cpython-313.pyc differ
 
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
- import io
5
- import base64
6
 
7
  import vlai_template
8
 
@@ -66,116 +66,143 @@ def standardize_data(X):
66
  """Standardize data to have mean=0 and std=1"""
67
  return (X - np.mean(X, axis=0)) / np.std(X, axis=0)
68
 
69
- def create_simple_plot_svg(X_pca, wine_types, labels, centroids_pca, k, explained_var):
70
- """Create a simple SVG plot"""
71
- width, height = 1400, 700
72
- margin = 80
73
 
74
- # Split into two plots
75
- plot_width = (width - 3 * margin) // 2
76
- plot_height = height - 2 * margin
77
-
78
- # Normalize data to plot coordinates
79
- x_min, x_max = X_pca[:, 0].min(), X_pca[:, 0].max()
80
- y_min, y_max = X_pca[:, 1].min(), X_pca[:, 1].max()
81
-
82
- def norm_x(x, plot_x_offset=0):
83
- return plot_x_offset + margin + (x - x_min) / (x_max - x_min) * plot_width
84
-
85
- def norm_y(y):
86
- return margin + (1 - (y - y_min) / (y_max - y_min)) * plot_height
87
-
88
- svg_content = f'''
89
- <div class="plot-container" style="width: 100%; overflow-x: auto;">
90
- <svg width="{width}" height="{height}" viewBox="0 0 {width} {height}"
91
- style="width: 100%; height: auto; min-width: 800px;" xmlns="http://www.w3.org/2000/svg">
92
- <defs>
93
- <style>
94
- .title {{ font: bold 20px sans-serif; text-anchor: middle; }}
95
- .axis-label {{ font: 16px sans-serif; text-anchor: middle; }}
96
- .legend {{ font: 14px sans-serif; }}
97
- </style>
98
- </defs>
99
-
100
- <!-- Plot 1: Original Wine Types -->
101
- <g>
102
- <rect x="{margin}" y="{margin}" width="{plot_width}" height="{plot_height}"
103
- fill="none" stroke="black" stroke-width="1"/>
104
- <text x="{margin + plot_width//2}" y="{margin - 10}" class="title">
105
- Original Data (Red vs White Wine)
106
- </text>
107
- <text x="{margin + plot_width//2}" y="{height - 10}" class="axis-label">
108
- PC1 ({explained_var[0]:.1%} variance)
109
- </text>
110
- <text x="{margin - 30}" y="{margin + plot_height//2}" class="axis-label"
111
- transform="rotate(-90, {margin - 30}, {margin + plot_height//2})">
112
- PC2 ({explained_var[1]:.1%} variance)
113
- </text>
114
- '''
115
-
116
- # Plot original data points
117
- for i, (point, wine_type) in enumerate(zip(X_pca, wine_types)):
118
- color = "#d62728" if wine_type == "red" else "#1f77b4" # red or blue
119
- x, y = norm_x(point[0]), norm_y(point[1])
120
- svg_content += f'<circle cx="{x}" cy="{y}" r="2" fill="{color}" opacity="0.6"/>\n'
121
 
122
- # Plot 1 legend
123
- svg_content += f'''
124
- <circle cx="{margin + 10}" cy="{margin + 20}" r="4" fill="#d62728"/>
125
- <text x="{margin + 20}" y="{margin + 25}" class="legend">Red Wine</text>
126
- <circle cx="{margin + 10}" cy="{margin + 40}" r="4" fill="#1f77b4"/>
127
- <text x="{margin + 20}" y="{margin + 45}" class="legend">White Wine</text>
128
- </g>
129
-
130
- <!-- Plot 2: KMeans Clusters -->
131
- <g>
132
- '''
133
 
134
- plot2_x_offset = plot_width + margin
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- svg_content += f'''
137
- <rect x="{margin + plot2_x_offset}" y="{margin}" width="{plot_width}" height="{plot_height}"
138
- fill="none" stroke="black" stroke-width="1"/>
139
- <text x="{margin + plot2_x_offset + plot_width//2}" y="{margin - 10}" class="title">
140
- After KMeans Clustering (K={k})
141
- </text>
142
- <text x="{margin + plot2_x_offset + plot_width//2}" y="{height - 10}" class="axis-label">
143
- PC1 ({explained_var[0]:.1%} variance)
144
- </text>
145
- <text x="{margin + plot2_x_offset - 30}" y="{margin + plot_height//2}" class="axis-label"
146
- transform="rotate(-90, {margin + plot2_x_offset - 30}, {margin + plot_height//2})">
147
- PC2 ({explained_var[1]:.1%} variance)
148
- </text>
149
- '''
150
 
151
- # Colors for clusters
152
  cluster_colors = ["#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22"]
153
 
154
- # Plot clustered data points
155
- for i, (point, label) in enumerate(zip(X_pca, labels)):
156
- color = cluster_colors[label % len(cluster_colors)]
157
- x, y = norm_x(point[0], plot2_x_offset), norm_y(point[1])
158
- svg_content += f'<circle cx="{x}" cy="{y}" r="2" fill="{color}" opacity="0.6"/>\n'
 
 
 
 
 
 
 
 
 
 
159
 
160
- # Plot centroids
161
- for i, centroid in enumerate(centroids_pca):
162
- x, y = norm_x(centroid[0], plot2_x_offset), norm_y(centroid[1])
163
- svg_content += f'<g><circle cx="{x}" cy="{y}" r="6" fill="black" stroke="white" stroke-width="2"/>'
164
- svg_content += f'<text x="{x}" y="{y+2}" text-anchor="middle" class="legend" fill="white">{i}</text></g>\n'
 
 
 
 
 
 
 
 
 
 
165
 
166
- # Plot 2 legend
167
- for i in range(k):
168
- color = cluster_colors[i % len(cluster_colors)]
169
- svg_content += f'<circle cx="{margin + plot2_x_offset + 10}" cy="{margin + 20 + i*20}" r="4" fill="{color}"/>\n'
170
- svg_content += f'<text x="{margin + plot2_x_offset + 20}" y="{margin + 25 + i*20}" class="legend">Cluster {i}</text>\n'
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- svg_content += '''
173
- </g>
174
- </svg>
175
- </div>
176
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- return svg_content
179
 
180
  def run_kmeans_analysis(n_clusters, random_state):
181
  """Main function to run the KMeans analysis"""
@@ -202,12 +229,27 @@ def run_kmeans_analysis(n_clusters, random_state):
202
 
203
  # Create plot
204
  wine_types = df['wine_type'].tolist() if 'wine_type' in df.columns else ['unknown'] * len(df)
205
- plot_svg = create_simple_plot_svg(X_pca, wine_types, labels, centroids_pca, n_clusters, explained_var)
206
 
207
- return plot_svg
208
 
209
  except Exception as e:
210
- return f"<p style='color: red;'>Error: {str(e)}</p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # ──────────────────────────── Main ─────────────────────────
213
  with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, title="Wine Quality KMeans Demo") as demo:
@@ -245,7 +287,7 @@ with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, title="Wine
245
  """)
246
 
247
  with gr.Column(scale=7):
248
- output_plot = gr.HTML(label="πŸ“ˆ PCA Visualization & KMeans Results")
249
 
250
  run_btn.click(
251
  run_kmeans_analysis,
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
+ import plotly.graph_objects as go
5
+ from plotly.subplots import make_subplots
6
 
7
  import vlai_template
8
 
 
66
  """Standardize data to have mean=0 and std=1"""
67
  return (X - np.mean(X, axis=0)) / np.std(X, axis=0)
68
 
69
+ def create_plotly_visualization(X_pca, wine_types, labels, centroids_pca, k, explained_var):
70
+ """Create a plotly visualization with two subplots"""
 
 
71
 
72
+ # Create subplots
73
+ fig = make_subplots(
74
+ rows=1, cols=2,
75
+ subplot_titles=(
76
+ "Original Data (Red vs White Wine)",
77
+ f"After KMeans Clustering (K={k})"
78
+ ),
79
+ horizontal_spacing=0.1
80
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ # Plot 1: Original Wine Types
83
+ red_mask = np.array(wine_types) == "red"
84
+ white_mask = np.array(wine_types) == "white"
 
 
 
 
 
 
 
 
85
 
86
+ # Add red wine points
87
+ if np.any(red_mask):
88
+ fig.add_trace(
89
+ go.Scatter(
90
+ x=X_pca[red_mask, 0],
91
+ y=X_pca[red_mask, 1],
92
+ mode='markers',
93
+ marker=dict(color='#d62728', size=4, opacity=0.6),
94
+ name='Red Wine',
95
+ showlegend=True
96
+ ),
97
+ row=1, col=1
98
+ )
99
 
100
+ # Add white wine points
101
+ if np.any(white_mask):
102
+ fig.add_trace(
103
+ go.Scatter(
104
+ x=X_pca[white_mask, 0],
105
+ y=X_pca[white_mask, 1],
106
+ mode='markers',
107
+ marker=dict(color='#1f77b4', size=4, opacity=0.6),
108
+ name='White Wine',
109
+ showlegend=True
110
+ ),
111
+ row=1, col=1
112
+ )
 
113
 
114
+ # Plot 2: KMeans Clusters
115
  cluster_colors = ["#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22"]
116
 
117
+ # Add cluster points
118
+ for i in range(k):
119
+ cluster_mask = labels == i
120
+ if np.any(cluster_mask):
121
+ fig.add_trace(
122
+ go.Scatter(
123
+ x=X_pca[cluster_mask, 0],
124
+ y=X_pca[cluster_mask, 1],
125
+ mode='markers',
126
+ marker=dict(color=cluster_colors[i % len(cluster_colors)], size=4, opacity=0.6),
127
+ name=f'Cluster {i}',
128
+ showlegend=True
129
+ ),
130
+ row=1, col=2
131
+ )
132
 
133
+ # Add centroids
134
+ fig.add_trace(
135
+ go.Scatter(
136
+ x=centroids_pca[:, 0],
137
+ y=centroids_pca[:, 1],
138
+ mode='markers+text',
139
+ marker=dict(color='black', size=12, line=dict(color='white', width=2)),
140
+ text=[str(i) for i in range(k)],
141
+ textfont=dict(color='white', size=10),
142
+ textposition="middle center",
143
+ name='Centroids',
144
+ showlegend=True
145
+ ),
146
+ row=1, col=2
147
+ )
148
 
149
+ # Update layout
150
+ fig.update_layout(
151
+ title="Wine Quality Dataset - KMeans Clustering with PCA",
152
+ title_x=0.5,
153
+ height=600,
154
+ plot_bgcolor='white',
155
+ paper_bgcolor='white',
156
+ font=dict(size=12),
157
+ showlegend=True,
158
+ legend=dict(
159
+ orientation="h",
160
+ yanchor="bottom",
161
+ y=-0.2,
162
+ xanchor="center",
163
+ x=0.5
164
+ )
165
+ )
166
 
167
+ # Update x and y axes labels and styling
168
+ fig.update_xaxes(
169
+ title_text=f"PC1 ({explained_var[0]:.1%} variance)",
170
+ showgrid=True,
171
+ gridcolor='lightgray',
172
+ gridwidth=1,
173
+ zeroline=True,
174
+ zerolinecolor='lightgray',
175
+ row=1, col=1
176
+ )
177
+ fig.update_xaxes(
178
+ title_text=f"PC1 ({explained_var[0]:.1%} variance)",
179
+ showgrid=True,
180
+ gridcolor='lightgray',
181
+ gridwidth=1,
182
+ zeroline=True,
183
+ zerolinecolor='lightgray',
184
+ row=1, col=2
185
+ )
186
+ fig.update_yaxes(
187
+ title_text=f"PC2 ({explained_var[1]:.1%} variance)",
188
+ showgrid=True,
189
+ gridcolor='lightgray',
190
+ gridwidth=1,
191
+ zeroline=True,
192
+ zerolinecolor='lightgray',
193
+ row=1, col=1
194
+ )
195
+ fig.update_yaxes(
196
+ title_text=f"PC2 ({explained_var[1]:.1%} variance)",
197
+ showgrid=True,
198
+ gridcolor='lightgray',
199
+ gridwidth=1,
200
+ zeroline=True,
201
+ zerolinecolor='lightgray',
202
+ row=1, col=2
203
+ )
204
 
205
+ return fig
206
 
207
  def run_kmeans_analysis(n_clusters, random_state):
208
  """Main function to run the KMeans analysis"""
 
229
 
230
  # Create plot
231
  wine_types = df['wine_type'].tolist() if 'wine_type' in df.columns else ['unknown'] * len(df)
232
+ plot_fig = create_plotly_visualization(X_pca, wine_types, labels, centroids_pca, n_clusters, explained_var)
233
 
234
+ return plot_fig
235
 
236
  except Exception as e:
237
+ # Return an empty plotly figure with error message
238
+ fig = go.Figure()
239
+ fig.add_annotation(
240
+ text=f"Error: {str(e)}",
241
+ xref="paper", yref="paper",
242
+ x=0.5, y=0.5,
243
+ showarrow=False,
244
+ font=dict(color="red", size=16)
245
+ )
246
+ fig.update_layout(
247
+ plot_bgcolor='white',
248
+ paper_bgcolor='white',
249
+ xaxis=dict(visible=False),
250
+ yaxis=dict(visible=False)
251
+ )
252
+ return fig
253
 
254
  # ──────────────────────────── Main ─────────────────────────
255
  with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, title="Wine Quality KMeans Demo") as demo:
 
287
  """)
288
 
289
  with gr.Column(scale=7):
290
+ output_plot = gr.Plot(label="πŸ“ˆ PCA Visualization & KMeans Results")
291
 
292
  run_btn.click(
293
  run_kmeans_analysis,
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio==5.38.0
2
  pandas
3
- numpy
 
 
1
  gradio==5.38.0
2
  pandas
3
+ numpy
4
+ plotly