# Understanding K-Means Clustering — interactive Streamlit demo on the iris dataset.
import base64

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
import streamlit as st
from scipy.spatial import ConvexHull, distance
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
st.set_page_config(layout="wide")

# CSS override: let the main content block span the full viewport width.
_FULL_WIDTH_CSS = """
<style>
.reportview-container .main .block-container {
    max-width: 100%;
}
</style>
"""
st.markdown(_FULL_WIDTH_CSS, unsafe_allow_html=True)
# Load the iris measurements (150 samples x 4 features).
iris = datasets.load_iris()
X = iris.data

st.title('Understanding K-Means Clustering')

tab1, tab2, about = st.tabs(["Basic ☕", "Advanced 🔬", " ℹ️ About"])

# Sidebar visibility flag, persisted across Streamlit reruns.
if "toggle" not in st.session_state:
    st.session_state.toggle = True
if st.button("Toggle Sidebar"):
    st.session_state.toggle = not st.session_state.toggle

# Keycap emojis used to label cluster numbers on the plots.
dmojis = ["0️⃣", "1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣", "6️⃣", "7️⃣", "8️⃣", "9️⃣"]

# Defaults used when the sidebar is hidden (sliders then never run).
user_features = [6.5, 3.5, 4.5, 1.5]
n_clusters_advanced = 2
if st.session_state.toggle:
    # Sidebar inputs: one slider per iris feature, all in centimetres.
    st.sidebar.header('Input Your Flower Data')

    def user_input_features():
        """Render the four feature sliders and return their values as a list."""
        specs = [
            ('Sepal Length (cm)', 4.0, 8.0, 6.5),
            ('Sepal Width (cm)', 2.0, 4.5, 3.5),
            ('Petal Length (cm)', 1.0, 7.0, 4.5),
            ('Petal Width (cm)', 0.1, 2.5, 1.5),
        ]
        return [st.sidebar.slider(label, lo, hi, default)
                for label, lo, hi, default in specs]

    # Re-read the sliders on every rerun so the plot tracks them live.
    user_features = user_input_features()

    st.sidebar.header('K-Means Parameters')
    n_clusters_advanced = st.sidebar.slider('Number of Clusters (K)', 1, 8, n_clusters_advanced)

# CSS override: allow the main column to scroll within a tall fixed height.
st.markdown("""
<style>
.reportview-container .main .block-container {
    overflow: auto;
    height: 2000px;
}
</style>
""", unsafe_allow_html=True)
with tab1:
    # --- Beginner-friendly clustering walkthrough ---
    st.write("""
### What is Clustering?
##### Clustering with K-Means is a machine learning concept like tidying a messy room by grouping similar items, but for data instead of physical objects.
""")

    # Initialize the PCA flag BEFORE any widget reads or toggles it; the
    # original toggled first and only initialized afterwards, which relies on
    # the button returning False on the very first run.
    if 'use_pca' not in st.session_state:
        st.session_state.use_pca = True

    # Button flips between a PCA projection and the raw first-two features.
    if st.button('Toggle PCA for Visualization'):
        st.session_state.use_pca = not st.session_state.use_pca

    if st.session_state.use_pca:
        st.write("""
##### 🧠 PCA (Principal Component Analysis) is like looking at a messy room from the best angle to see the most mess. It helps us see our data more clearly!
""")
        # Project the 4-D iris data down to its 2 principal components.
        pca = PCA(n_components=2)
        X_transformed = pca.fit_transform(X)
        user_features_transformed = pca.transform([user_features])[0]
    else:
        # No PCA: visualize the first two raw features only.
        X_transformed = X[:, :2]
        user_features_transformed = user_features[:2]

    st.write("""
### Visualizing Groups
##### Here are the groups from our tidying method. Each color has a number at its center, representing its group.
""")

    # DataFrame for convenient per-cluster filtering below.
    df_transformed = pd.DataFrame(X_transformed, columns=['Feature1', 'Feature2'])

    # Fixed n_init/random_state keep cluster labels (and therefore colors and
    # group numbers) stable across Streamlit reruns, and avoid sklearn's
    # n_init FutureWarning. Previously labels could reshuffle on every rerun.
    kmeans = KMeans(n_clusters=n_clusters_advanced, n_init=10, random_state=42)
    y_kmeans = kmeans.fit_predict(X_transformed)
    df_transformed['cluster'] = y_kmeans

    # Predict the cluster for the user's flower in the same 2-D space.
    predicted_cluster = kmeans.predict([user_features_transformed])

    fig = go.Figure()

    # Shade each cluster's footprint with its convex hull.
    for cluster in np.unique(y_kmeans):
        cluster_data = df_transformed[df_transformed['cluster'] == cluster]
        x_data = cluster_data['Feature1'].values
        y_data = cluster_data['Feature2'].values
        if len(cluster_data) > 2:  # ConvexHull requires at least 3 points
            hull = ConvexHull(cluster_data[['Feature1', 'Feature2']])
            fig.add_trace(go.Scatter(x=x_data[hull.vertices], y=y_data[hull.vertices], fill='toself', fillcolor=px.colors.qualitative.Set1[cluster], opacity=0.5, line=dict(width=0), showlegend=False))

    # Scatter of all samples; square markers signal "PCA off".
    if st.session_state.use_pca:
        fig.add_trace(go.Scatter(x=df_transformed['Feature1'], y=df_transformed['Feature2'], mode='markers', marker=dict(color=y_kmeans, colorscale=px.colors.qualitative.Set1), showlegend=False))
    else:
        fig.add_trace(go.Scatter(x=df_transformed['Feature1'], y=df_transformed['Feature2'], mode='markers', marker=dict(color=y_kmeans, colorscale=px.colors.qualitative.Set1, symbol='square'), showlegend=False))

    # The user's flower, drawn as a large white star.
    fig.add_trace(go.Scatter(x=[user_features_transformed[0]], y=[user_features_transformed[1]], mode='markers', marker=dict(symbol='star', size=30, color='white')))

    # Number each centroid (1-based, matching the prediction text below).
    for i, coord in enumerate(kmeans.cluster_centers_):
        fig.add_annotation(
            x=coord[0],
            y=coord[1],
            text=dmojis[i + 1],
            showarrow=True,
            font=dict(color='white', size=30)
        )

    fig.update_layout(width=1200, height=500)
    st.plotly_chart(fig)

    # Second PCA toggle below the chart (unique key allows the duplicate label).
    if st.button('Toggle PCA for Visualization', key=125):
        st.session_state.use_pca = not st.session_state.use_pca
    if st.session_state.use_pca:
        st.write("""
##### 🧠 PCA (Principal Component Analysis) is like looking at a messy room from the best angle to see the most mess. It helps us see our data more clearly!
""")

    st.write(f"##### Overlapping clusters mean some flowers are very similar and hard to tell apart just by looking at these features.")
    st.write(f"# Based on your flower data (⭐), it likely belongs to **Group {dmojis[predicted_cluster[0]+1]}**")

    # Closing note.
    st.write("""
### Wrap Up
##### Just as sorting toys in a room, we group flowers by features; adjust the data to pick a flower and set how many boxes (groups) you want to use.
""")
with tab2:
    # --- Advanced explanation with the WCSS objective in LaTeX ---
    # Raw strings: the markdown contains literal \( \) delimiters, which would
    # otherwise be invalid escape sequences in a normal string literal.
    st.write(r"""
## Advanced Overview of Clustering
Clustering, in the context of machine learning, refers to the task of partitioning the dataset into groups, known as clusters. The aim is to segregate groups with similar traits and assign them into clusters.
### K-Means Algorithm
The K-Means clustering method is an iterative method that tries to partition the dataset into \(K\) pre-defined distinct non-overlapping subgroups (clusters) where each data point belongs to only one group.
Here's a brief rundown:
1. **Initialization**: Choose \(K\) initial centroids. (Centroids is a fancy term for 'the center of the cluster'.)
2. **Assignment**: Assign each data point to the nearest centroid. All the points assigned to a centroid form a cluster.
3. **Update**: Recompute the centroid of each cluster.
4. **Repeat**: Keep repeating steps 2 and 3 until the centroids no longer move too much.
""")
    st.write("The mathematical goal is to minimize the within-cluster sum of squares. The formula is:")
    st.latex(r'''
\mathrm{WCSS} = \sum_{i=1}^{K} \sum_{x \in C_i} \| x - \mu_i \|^2
''')
    st.latex(r'''
\begin{align*}
\text{Where:} \\
& \mathrm{WCSS} \text{ is the within-cluster sum of squares we want to minimize.} \\
& K \text{ is the number of clusters.} \\
& C_i \text{ is the i-th cluster.} \\
& \mu_i \text{ is the centroid of the i-th cluster.} \\
& x \text{ is a data point in cluster } C_i.
\end{align*}
''')
    st.write(r"""
The K-Means algorithm tries to find the best centroids such that the \( \mathrm{WCSS} \) is minimized.
""")

    # Initialize before toggling — guards the first-ever click on this button
    # (the original read st.session_state.use_pca before initializing it).
    if 'use_pca' not in st.session_state:
        st.session_state.use_pca = True
    if st.button('Toggle PCA for Visualization', key=12):
        st.session_state.use_pca = not st.session_state.use_pca

    if st.session_state.use_pca:
        st.write("""
##### 🧠 PCA (Principal Component Analysis) is a mathematical technique that helps us view our data from the best perspective. It identifies the directions (principal components) that maximize variance, allowing us to see patterns and structures more clearly.
""")
        # Project the 4-D data to 2 principal components.
        pca = PCA(n_components=2)
        X_transformed = pca.fit_transform(X)
        user_features_transformed = pca.transform([user_features])[0]
    else:
        X_transformed = X[:, :2]  # first two raw features only
        user_features_transformed = user_features[:2]

    # Deterministic K-Means (same rationale as tab1): stable labels across
    # reruns and no sklearn n_init FutureWarning.
    kmeans_advanced = KMeans(n_clusters=n_clusters_advanced, n_init=10, random_state=42)
    y_kmeans_advanced = kmeans_advanced.fit_predict(X_transformed)

    df_transformed = pd.DataFrame(X_transformed, columns=['Feature1', 'Feature2'])
    df_transformed['cluster'] = y_kmeans_advanced

    fig_advanced = go.Figure()

    # Convex-hull shading per cluster.
    for cluster in np.unique(y_kmeans_advanced):
        cluster_data = df_transformed[df_transformed['cluster'] == cluster]
        x_data = cluster_data['Feature1'].values
        y_data = cluster_data['Feature2'].values
        if len(cluster_data) > 2:  # ConvexHull requires at least 3 points
            hull = ConvexHull(cluster_data[['Feature1', 'Feature2']])
            fig_advanced.add_trace(go.Scatter(x=x_data[hull.vertices], y=y_data[hull.vertices], fill='toself', fillcolor=px.colors.qualitative.Set1[cluster], opacity=0.5, line=dict(width=0), showlegend=False))

    # All samples, colored by cluster assignment.
    fig_advanced.add_trace(go.Scatter(x=df_transformed['Feature1'], y=df_transformed['Feature2'], mode='markers', marker=dict(color=y_kmeans_advanced, colorscale=px.colors.qualitative.Set1), showlegend=False))

    # The user's flower as a white star.
    fig_advanced.add_trace(go.Scatter(x=[user_features_transformed[0]], y=[user_features_transformed[1]], mode='markers', marker=dict(symbol='star', size=30, color='white')))

    # Numbered centroid annotations (1-based, matching tab1's convention).
    for i, coord in enumerate(kmeans_advanced.cluster_centers_):
        fig_advanced.add_annotation(
            x=coord[0],
            y=coord[1],
            text=dmojis[i + 1],
            showarrow=True,
            font=dict(color='white', size=30)
        )

    fig_advanced.update_layout(width=1200, height=500)
    st.plotly_chart(fig_advanced)

    # Description corrected: centroids are drawn as numbered emoji markers,
    # not "big gray X marks" as the original text claimed.
    st.write("""
### Interpretation
The plot displays how data points are grouped into clusters. The numbered markers represent the center of each cluster, known as centroids. The positioning of these centroids is determined by the mean of all data points in the cluster.
Keep in mind that the positioning of these centroids is crucial, as they determine the grouping of data. The algorithm tries to place them in such a way that the distance between the data points and their respective centroid is minimized.
**Feel free to adjust the number of clusters to see how data points get re-grouped!**
""")
with about:
    # Author / credits tab.
    st.title("About")
    st.markdown("""
## Created by **Mustafa Alhamad**.
""")
    st.markdown('[<img src="https://www.iconpacks.net/icons/2/free-linkedin-logo-icon-2430-thumb.png" width="128" height="128"/>](https://www.linkedin.com/in/mustafa-al-hamad-975b67213/)', unsafe_allow_html=True)
    st.markdown('### Made with <img src="https://streamlit.io/images/brand/streamlit-logo-secondary-colormark-darktext.svg" width="512" height="512"/>', unsafe_allow_html=True)

# Hide Streamlit's toolbar and footer chrome on every page.
hide_streamlit_style = """
<style>
[data-testid="stToolbar"] {visibility: hidden !important;}
footer {visibility: hidden !important;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)