XPMaster committed on
Commit
a1e8fa8
·
1 Parent(s): d0528b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -194
app.py CHANGED
@@ -3,18 +3,15 @@ import pandas as pd
3
  import numpy as np
4
  from sklearn import datasets
5
  from sklearn.cluster import KMeans
6
- import matplotlib.pyplot as plt
7
  import plotly.express as px
8
- import base64
9
  import plotly.figure_factory as ff
10
  import plotly.graph_objects as go
11
  from scipy.spatial import ConvexHull
12
- from scipy.spatial import distance
13
  from sklearn.decomposition import PCA
14
 
15
  st.set_page_config(layout="wide")
16
 
17
- # JS hack to add a toggle button for the sidebar
18
  st.markdown("""
19
  <style>
20
  .reportview-container .main .block-container {
@@ -23,13 +20,21 @@ st.markdown("""
23
  </style>
24
  """, unsafe_allow_html=True)
25
 
 
 
 
 
 
 
 
 
 
26
  # Load iris dataset
27
  iris = datasets.load_iris()
28
  X = iris.data
29
 
30
  st.title('Understanding K-Means Clustering')
31
-
32
- tab1, tab2, about = st.tabs(["Basic ☕", "Advanced 🔬"," ℹ️ About"])
33
 
34
  if "toggle" not in st.session_state:
35
  st.session_state.toggle = True
@@ -41,7 +46,7 @@ if toggle_button:
41
 
42
  dmojis = ["0️⃣", "1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣", "6️⃣", "7️⃣", "8️⃣", "9️⃣"]
43
 
44
- # Initialize user_features and n_clusters_advanced outside of any condition
45
  user_features = [6.5, 3.5, 4.5, 1.5]
46
  n_clusters_advanced = 2
47
 
@@ -56,233 +61,101 @@ if st.session_state.toggle:
56
  petal_width = st.sidebar.slider('Petal Width (cm)', 0.1, 2.5, 1.5)
57
  return [sepal_length, sepal_width, petal_length, petal_width]
58
 
59
- user_features = user_input_features() # Update the user_features variable when sliders change
60
 
61
  # Slider for Advanced in the sidebar
62
  st.sidebar.header('K-Means Parameters')
63
  n_clusters_advanced = st.sidebar.slider('Number of Clusters (K)', 1, 8, n_clusters_advanced)
64
 
 
 
 
65
 
66
- st.markdown("""
67
- <style>
68
- .reportview-container .main .block-container {
69
- overflow: auto;
70
- height: 2000px;
71
- }
72
- </style>
73
- """, unsafe_allow_html=True)
74
-
75
- with tab1:
76
- st.write("""
77
- ### What is Clustering?
78
- ##### Clustering with K-Means is a machine learning concept like tidying a messy room by grouping similar items, but for data instead of physical objects.
79
- """)
80
-
81
- # Button to toggle PCA
82
- if st.button('Toggle PCA for Visualization'):
83
  st.session_state.use_pca = not st.session_state.use_pca
84
 
85
- # Check if 'use_pca' is already in the session state
86
  if 'use_pca' not in st.session_state:
87
  st.session_state.use_pca = True
88
 
89
  if st.session_state.use_pca:
90
- st.write("""
91
- ##### 🧠 PCA (Principal Component Analysis) is like looking at a messy room from the best angle to see the most mess. It helps us see our data more clearly!
92
- """)
93
- # Apply PCA for dimensionality reduction
94
  pca = PCA(n_components=2)
95
  X_transformed = pca.fit_transform(X)
96
  user_features_transformed = pca.transform([user_features])[0]
97
  else:
98
- X_transformed = X[:, :2] # Just use the first two features for visualization
99
  user_features_transformed = user_features[:2]
100
 
101
- st.write("""
102
- ### Visualizing Groups
103
- ##### Here are the groups from our tidying method. Each color has a number at its center, representing its group.
104
- """)
105
-
106
- # Create a DataFrame for easier plotting with plotly
107
- df_transformed = pd.DataFrame(X_transformed, columns=['Feature1', 'Feature2'])
108
-
109
- # K-Means Algorithm
110
  kmeans = KMeans(n_clusters=n_clusters_advanced)
111
  y_kmeans = kmeans.fit_predict(X_transformed)
 
 
 
112
  df_transformed['cluster'] = y_kmeans
113
 
114
- # Predict the cluster for the user input in the transformed space
115
  predicted_cluster = kmeans.predict([user_features_transformed])
116
-
117
- # For tab1
118
  fig = go.Figure()
119
-
120
  # Add shaded regions using convex hull
121
  for cluster in np.unique(y_kmeans):
122
  cluster_data = df_transformed[df_transformed['cluster'] == cluster]
123
- x_data = cluster_data['Feature1'].values
124
- y_data = cluster_data['Feature2'].values
125
- if len(cluster_data) > 2: # ConvexHull requires at least 3 points
126
  hull = ConvexHull(cluster_data[['Feature1', 'Feature2']])
127
- fig.add_trace(go.Scatter(x=x_data[hull.vertices], y=y_data[hull.vertices], fill='toself', fillcolor=px.colors.qualitative.Set1[cluster], opacity=0.5, line=dict(width=0), showlegend=False))
128
-
129
- # Add scatter plot based on PCA toggle
130
- if st.session_state.use_pca:
131
- fig.add_trace(go.Scatter(x=df_transformed['Feature1'], y=df_transformed['Feature2'], mode='markers', marker=dict(color=y_kmeans, colorscale=px.colors.qualitative.Set1), showlegend=False))
132
- else:
133
- fig.add_trace(go.Scatter(x=df_transformed['Feature1'], y=df_transformed['Feature2'], mode='markers', marker=dict(color=y_kmeans, colorscale=px.colors.qualitative.Set1, symbol='square'), showlegend=False))
134
-
135
- # Add user input as a star marker
136
- fig.add_trace(go.Scatter(x=[user_features_transformed[0]], y=[user_features_transformed[1]], mode='markers', marker=dict(symbol='star', size=30, color='white')))
137
-
138
- # Add centroids with group numbers
139
- for i, coord in enumerate(kmeans.cluster_centers_):
140
- fig.add_annotation(
141
- x=coord[0],
142
- y=coord[1],
143
- text=dmojis[i+1],
144
- showarrow=True,
145
- font=dict(color='white', size=30)
146
- )
147
-
148
- # Update layout
149
- fig.update_layout(width=1200, height=500)
150
- st.plotly_chart(fig)
151
-
152
- # Button to toggle PCA
153
- if st.button('Toggle PCA for Visualization',key=125):
154
- st.session_state.use_pca = not st.session_state.use_pca
155
-
156
- if st.session_state.use_pca:
157
- st.write("""
158
- ##### 🧠 PCA (Principal Component Analysis) is like looking at a messy room from the best angle to see the most mess. It helps us see our data more clearly!
159
- """)
160
-
161
- st.write(f"##### Overlapping clusters mean some flowers are very similar and hard to tell apart just by looking at these features.")
162
- st.write(f"# Based on your flower data (⭐), it likely belongs to **Group {dmojis[predicted_cluster[0]+1]}**")
163
-
164
- # Closing Note
165
- st.write("""
166
- ### Wrap Up
167
- ##### Just as sorting toys in a room, we group flowers by features; adjust the data to pick a flower and set how many boxes (groups) you want to use.
168
- """)
169
-
170
-
171
- with tab2:
172
- st.write("""
173
- ## Advanced Overview of Clustering
174
-
175
- Clustering, in the context of machine learning, refers to the task of partitioning the dataset into groups, known as clusters. The aim is to segregate groups with similar traits and assign them into clusters.
176
-
177
- ### K-Means Algorithm
178
-
179
- The K-Means clustering method is an iterative method that tries to partition the dataset into \(K\) pre-defined distinct non-overlapping subgroups (clusters) where each data point belongs to only one group.
180
-
181
- Here's a brief rundown:
182
-
183
- 1. **Initialization**: Choose \(K\) initial centroids. (Centroids is a fancy term for 'the center of the cluster'.)
184
- 2. **Assignment**: Assign each data point to the nearest centroid. All the points assigned to a centroid form a cluster.
185
- 3. **Update**: Recompute the centroid of each cluster.
186
- 4. **Repeat**: Keep repeating steps 2 and 3 until the centroids no longer move too much.
187
- """)
188
-
189
- st.write("The mathematical goal is to minimize the within-cluster sum of squares. The formula is:")
190
- st.latex(r'''
191
- \mathrm{WCSS} = \sum_{i=1}^{K} \sum_{x \in C_i} \| x - \mu_i \|^2
192
- ''')
193
-
194
- st.latex(r'''
195
- \begin{align*}
196
- \text{Where:} \\
197
- & \mathrm{WCSS} \text{ is the within-cluster sum of squares we want to minimize.} \\
198
- & K \text{ is the number of clusters.} \\
199
- & C_i \text{ is the i-th cluster.} \\
200
- & \mu_i \text{ is the centroid of the i-th cluster.} \\
201
- & x \text{ is a data point in cluster } C_i.
202
- \end{align*}
203
- ''')
204
-
205
- st.write("""
206
- The K-Means algorithm tries to find the best centroids such that the \( \mathrm{WCSS} \) is minimized.
207
- """)
208
-
209
- # Button to toggle PCA
210
- if st.button('Toggle PCA for Visualization', key=12):
211
- st.session_state.use_pca = not st.session_state.use_pca
212
-
213
- # Check if 'use_pca' is already in the session state
214
- if 'use_pca' not in st.session_state:
215
- st.session_state.use_pca = True
216
 
217
- if st.session_state.use_pca:
218
- st.write("""
219
- ##### 🧠 PCA (Principal Component Analysis) is a mathematical technique that helps us view our data from the best perspective. It identifies the directions (principal components) that maximize variance, allowing us to see patterns and structures more clearly.
220
- """)
221
- # Apply PCA for dimensionality reduction
222
- pca = PCA(n_components=2)
223
- X_transformed = pca.fit_transform(X)
224
- user_features_transformed = pca.transform([user_features])[0]
225
- else:
226
- X_transformed = X[:, :2] # Just use the first two features for visualization
227
- user_features_transformed = user_features[:2]
228
 
229
- # K-Means Algorithm for Advanced Tab
230
- kmeans_advanced = KMeans(n_clusters=n_clusters_advanced)
231
- y_kmeans_advanced = kmeans_advanced.fit_predict(X_transformed)
232
 
233
- # Create a DataFrame for easier plotting with plotly
234
- df_transformed = pd.DataFrame(X_transformed, columns=['Feature1', 'Feature2'])
235
- df_transformed['cluster'] = y_kmeans_advanced
236
 
237
- fig_advanced = go.Figure()
 
 
 
 
238
 
239
- # Add shaded regions using convex hull
240
- for cluster in np.unique(y_kmeans_advanced):
241
- cluster_data = df_transformed[df_transformed['cluster'] == cluster]
242
- x_data = cluster_data['Feature1'].values
243
- y_data = cluster_data['Feature2'].values
244
- if len(cluster_data) > 2: # ConvexHull requires at least 3 points
245
- hull = ConvexHull(cluster_data[['Feature1', 'Feature2']])
246
- fig_advanced.add_trace(go.Scatter(x=x_data[hull.vertices], y=y_data[hull.vertices], fill='toself', fillcolor=px.colors.qualitative.Set1[cluster], opacity=0.5, line=dict(width=0), showlegend=False))
247
 
248
- # Add scatter plot based on PCA toggle
249
- fig_advanced.add_trace(go.Scatter(x=df_transformed['Feature1'], y=df_transformed['Feature2'], mode='markers', marker=dict(color=y_kmeans_advanced, colorscale=px.colors.qualitative.Set1), showlegend=False))
250
 
251
- # Add user input as a star marker
252
- fig_advanced.add_trace(go.Scatter(x=[user_features_transformed[0]], y=[user_features_transformed[1]], mode='markers', marker=dict(symbol='star', size=30, color='white')))
 
 
 
 
 
253
 
254
- # Add centroids with group numbers
255
- for i, coord in enumerate(kmeans_advanced.cluster_centers_):
256
- fig_advanced.add_annotation(
257
- x=coord[0],
258
- y=coord[1],
259
- text=dmojis[i+1],
260
- showarrow=True,
261
- font=dict(color='white', size=30)
262
- )
263
 
264
- # Update layout
265
- fig_advanced.update_layout(width=1200, height=500)
266
- st.plotly_chart(fig_advanced)
267
 
 
 
 
268
  st.write("""
269
- ### Interpretation
270
-
271
- The plot displays how data points are grouped into clusters. The big gray X marks represent the center of each cluster, known as centroids. The positioning of these centroids is determined by the mean of all data points in the cluster.
272
-
273
- Keep in mind that the positioning of these centroids is crucial, as they determine the grouping of data. The algorithm tries to place them in such a way that the distance between the data points and their respective centroid is minimized.
274
-
275
- **Feel free to adjust the number of clusters to see how data points get re-grouped!**
276
- """)
277
-
278
 
279
- with about:
280
- st.title("About")
281
- st.markdown("""
282
- ## Created by **Mustafa Alhamad**.
283
- """)
284
- st.markdown('[<img src="https://www.iconpacks.net/icons/2/free-linkedin-logo-icon-2430-thumb.png" width="128" height="128"/>](https://www.linkedin.com/in/mustafa-al-hamad-975b67213/)', unsafe_allow_html=True)
285
- st.markdown('### Made with <img src="https://streamlit.io/images/brand/streamlit-logo-secondary-colormark-darktext.svg" width="512" height="512"/>', unsafe_allow_html=True)
286
 
287
  hide_streamlit_style = """
288
  <style>
 
3
  import numpy as np
4
  from sklearn import datasets
5
  from sklearn.cluster import KMeans
 
6
  import plotly.express as px
 
7
  import plotly.figure_factory as ff
8
  import plotly.graph_objects as go
9
  from scipy.spatial import ConvexHull
 
10
  from sklearn.decomposition import PCA
11
 
12
  st.set_page_config(layout="wide")
13
 
14
+ # Styles
15
  st.markdown("""
16
  <style>
17
  .reportview-container .main .block-container {
 
20
  </style>
21
  """, unsafe_allow_html=True)
22
 
23
+ st.markdown("""
24
+ <style>
25
+ .reportview-container .main .block-container {
26
+ overflow: auto;
27
+ height: 2000px;
28
+ }
29
+ </style>
30
+ """, unsafe_allow_html=True)
31
+
32
  # Load iris dataset
33
  iris = datasets.load_iris()
34
  X = iris.data
35
 
36
  st.title('Understanding K-Means Clustering')
37
+ tab1, tab2, about = st.tabs(["Basic ☕", "Advanced 🔬", " ℹ️ About"])
 
38
 
39
  if "toggle" not in st.session_state:
40
  st.session_state.toggle = True
 
46
 
47
  dmojis = ["0️⃣", "1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣", "6️⃣", "7️⃣", "8️⃣", "9️⃣"]
48
 
49
+ # Initialize user_features and n_clusters_advanced
50
  user_features = [6.5, 3.5, 4.5, 1.5]
51
  n_clusters_advanced = 2
52
 
 
61
  petal_width = st.sidebar.slider('Petal Width (cm)', 0.1, 2.5, 1.5)
62
  return [sepal_length, sepal_width, petal_length, petal_width]
63
 
64
+ user_features = user_input_features()
65
 
66
  # Slider for Advanced in the sidebar
67
  st.sidebar.header('K-Means Parameters')
68
  n_clusters_advanced = st.sidebar.slider('Number of Clusters (K)', 1, 8, n_clusters_advanced)
69
 
70
+ def plot_clusters(tab_name):
71
+ # This function will handle the plotting logic for both tabs to avoid repetition.
72
+ # It will return the appropriate plot based on the tab name.
73
 
74
+ # Toggle PCA for Visualization
75
+ if st.button(f'Toggle PCA for Visualization ({tab_name})', key=f'toggle_pca_{tab_name}'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  st.session_state.use_pca = not st.session_state.use_pca
77
 
78
+ # Check 'use_pca' in session state
79
  if 'use_pca' not in st.session_state:
80
  st.session_state.use_pca = True
81
 
82
  if st.session_state.use_pca:
 
 
 
 
83
  pca = PCA(n_components=2)
84
  X_transformed = pca.fit_transform(X)
85
  user_features_transformed = pca.transform([user_features])[0]
86
  else:
87
+ X_transformed = X[:, :2]
88
  user_features_transformed = user_features[:2]
89
 
90
+ # K-Means
 
 
 
 
 
 
 
 
91
  kmeans = KMeans(n_clusters=n_clusters_advanced)
92
  y_kmeans = kmeans.fit_predict(X_transformed)
93
+
94
+ # Create a DataFrame for plotting
95
+ df_transformed = pd.DataFrame(X_transformed, columns=['Feature1', 'Feature2'])
96
  df_transformed['cluster'] = y_kmeans
97
 
98
+ # Predict the cluster for user input
99
  predicted_cluster = kmeans.predict([user_features_transformed])
100
+
101
+ # Plot
102
  fig = go.Figure()
103
+
104
  # Add shaded regions using convex hull
105
  for cluster in np.unique(y_kmeans):
106
  cluster_data = df_transformed[df_transformed['cluster'] == cluster]
107
+ if len(cluster_data) > 2: # At least 3 points are needed
 
 
108
  hull = ConvexHull(cluster_data[['Feature1', 'Feature2']])
109
+ fig.add_trace(go.Scatter(x=cluster_data['Feature1'].values[hull.vertices], y=cluster_data['Feature2'].values[hull.vertices], fill='toself', fillcolor=px.colors.qualitative.Set1[cluster], opacity=0.5, line=dict(width=0), showlegend=False))
110
+
111
+ # Add data points
112
+ for cluster in np.unique(y_kmeans):
113
+ cluster_data = df_transformed[df_transformed['cluster'] == cluster]
114
+ fig.add_trace(go.Scatter(x=cluster_data['Feature1'], y=cluster_data['Feature2'], mode='markers', marker=dict(color=px.colors.qualitative.Set1[cluster], opacity=0.75), name=f"Cluster {dmojis[cluster]}"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ # Add user input point
117
+ fig.add_trace(go.Scatter(x=[user_features_transformed[0]], y=[user_features_transformed[1]], mode='markers', marker=dict(size=10, color='black'), name=f"Your Flower"))
 
 
 
 
 
 
 
 
 
118
 
119
+ # Add cluster centers
120
+ fig.add_trace(go.Scatter(x=kmeans.cluster_centers_[:, 0], y=kmeans.cluster_centers_[:, 1], mode='markers', marker=dict(size=10, color='red', symbol='cross'), name='Cluster Centers'))
 
121
 
122
+ fig.update_layout(title=f'{tab_name} Clustering', xaxis_title='Feature 1', yaxis_title='Feature 2', hovermode='closest', legend_title="Clusters & Points")
 
 
123
 
124
+ return fig
125
+
126
+ if tab1:
127
+ # Basic Tab
128
+ st.markdown("## Basic Clustering")
129
 
130
+ st.write("Using the flower data you've input on the sidebar, let's see how K-means clustering works in the basic mode with 2 clusters.")
 
 
 
 
 
 
 
131
 
132
+ n_clusters_basic = 2
 
133
 
134
+ # Plot
135
+ basic_fig = plot_clusters("Basic")
136
+ st.plotly_chart(basic_fig)
137
+
138
+ elif tab2:
139
+ # Advanced Tab
140
+ st.markdown("## Advanced Clustering")
141
 
142
+ st.write(f"Using the flower data you've input on the sidebar, let's see how K-means clustering works with {n_clusters_advanced} clusters.")
 
 
 
 
 
 
 
 
143
 
144
+ # Plot
145
+ advanced_fig = plot_clusters("Advanced")
146
+ st.plotly_chart(advanced_fig)
147
 
148
+ else:
149
+ # About Tab
150
+ st.markdown("## About this app")
151
  st.write("""
152
+ This app demonstrates the K-means clustering algorithm using the Iris dataset. The user can input their own data points to see where they fall within the clusters.
153
+ - **Basic Tab:** The data is clustered into 2 groups.
154
+ - **Advanced Tab:** The user can specify the number of clusters.
155
+ - **PCA Toggle:** This option allows for reducing the dataset's dimensions to 2D for easy visualization. The PCA transformed data is used for plotting purposes only, while the original data is used for clustering.
156
+ """)
 
 
 
 
157
 
158
+ # End of Streamlit app
 
 
 
 
 
 
159
 
160
  hide_streamlit_style = """
161
  <style>