# Source: kheejay88 — commit 01fa893 "Update app.py" (verified)
import streamlit as st
from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# --------------------- Streamlit App ---------------------
# Page-level settings: page title plus a full-width ("wide") layout.
st.set_page_config(layout="wide", page_title="Unsupervised ML: Iris Clustering")
# Load the Iris measurements and keep a standardized copy for clustering.
@st.cache_data
def load_data():
    """Return the raw Iris features, their standardized form and the column names.

    Returns:
        tuple: ``(X, X_scaled, feature_names)`` where ``X`` is ``load_iris().data``,
        ``X_scaled`` is ``X`` transformed by a ``StandardScaler`` fitted on ``X``,
        and ``feature_names`` is ``load_iris().feature_names``.
    """
    dataset = load_iris()
    raw = dataset.data
    standardized = StandardScaler().fit_transform(raw)
    return raw, standardized, dataset.feature_names


X, X_scaled, feature_names = load_data()
# Perform K-Means clustering
@st.cache_data
def perform_clustering(X_scaled, n_clusters=3, random_state=42, n_init=10):
    """Fit K-Means on the standardized features.

    Args:
        X_scaled: Standardized feature matrix to cluster.
        n_clusters: Number of clusters to form.
        random_state: Seed for centroid initialization, so reruns give the
            same clusters.
        n_init: Number of centroid initializations; scikit-learn keeps the run
            with the lowest inertia. Passed explicitly because the library
            default changed from 10 to "auto" in scikit-learn 1.4 (with a
            FutureWarning in 1.2-1.3), which would otherwise make the
            clustering depend on the installed version.

    Returns:
        tuple: ``(kmeans, clusters)`` - the fitted ``KMeans`` estimator and the
        cluster id assigned to each row of ``X_scaled``.
    """
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=n_init)
    clusters = kmeans.fit_predict(X_scaled)
    return kmeans, clusters


kmeans, clusters = perform_clustering(X_scaled)
# Create a DataFrame with the clustering results
@st.cache_data
def create_clustered_dataframe(
    X,
    clusters,
    feature_names,
    label_names=("Setosa-like", "Versicolor-like", "Virginica-like"),
):
    """Attach cluster ids and descriptive cluster labels to the raw features.

    K-Means numbers its clusters arbitrarily, so a fixed id -> species mapping
    is not reliable. Instead, clusters are ranked by their mean petal length
    (smallest first) and the cluster at rank ``i`` receives ``label_names[i]``.
    On the Iris data, mean petal length orders the species
    setosa < versicolor < virginica, which is what the default names assume.

    Args:
        X: Raw (unscaled) feature matrix, one row per sample.
        clusters: Cluster id for each row of ``X``.
        feature_names: Column names for ``X``.
        label_names: Labels assigned in ascending order of the ranking score.
            Clusters ranked beyond the end of this sequence are labelled
            ``"Cluster <id>"``.

    Returns:
        tuple: ``(df, cluster_labels)`` where ``df`` holds the features plus
        ``'Cluster'`` and ``'Cluster Label'`` columns, and ``cluster_labels``
        maps each cluster id (int) to its label.
    """
    df = pd.DataFrame(X, columns=feature_names)
    df['Cluster'] = clusters

    # Rank on petal length when that column is present; otherwise fall back to
    # the mean over all feature columns so the function still works for other
    # feature sets.
    grouped = df.groupby('Cluster')
    if 'petal length (cm)' in df.columns:
        score = grouped['petal length (cm)'].mean()
    else:
        score = grouped[list(feature_names)].mean().mean(axis=1)

    ranked_ids = score.sort_values(kind='stable').index
    cluster_labels = {
        int(cluster_id): (
            label_names[rank] if rank < len(label_names) else f'Cluster {cluster_id}'
        )
        for rank, cluster_id in enumerate(ranked_ids)
    }
    df['Cluster Label'] = df['Cluster'].map(cluster_labels)
    return df, cluster_labels


df, cluster_labels = create_clustered_dataframe(X, clusters, feature_names)
# App title (emoji restored from mis-encoded UTF-8)
st.title("🌸 Unsupervised Machine Learning: Iris Clustering App")
# Tabs for organization: About, Data Visualization, Model Prediction
tab1, tab2, tab3 = st.tabs(["🏠 About", "📊 Data Visualization", "🔎 Model Prediction"])
# ------------- About Tab -------------
with tab1:
    st.header("About This App")
    # Streamlit dedents markdown bodies, so the uniform indentation below is
    # removed; the extra indent on the sub-bullets nests them under each step.
    st.markdown("""
    ## **Overview**
    This application demonstrates **unsupervised machine learning** using the Iris dataset.
    The app clusters data points based on the features of iris flowers using the **K-Means clustering algorithm**.
    After clustering, meaningful labels are assigned based on the cluster’s statistical properties.
    ## **How It Works**
    1. **Data Preprocessing:**
       - The dataset is standardized using `StandardScaler` to ensure uniform feature scaling.
    2. **Clustering:**
       - K-Means clustering is applied to group the data into **three clusters**.
       - The number of clusters is based on the natural grouping of the Iris dataset.
    3. **Cluster Labeling:**
       - After clustering, each cluster is assigned a meaningful label based on its centroid properties and domain knowledge.
    4. **Model Testing:**
       - The app allows the user to enter custom feature values.
       - The model predicts the cluster and assigns a meaningful label to the input data.
    ## **Dataset Information**
    """)
    # Preview the first rows of the raw measurements. X / feature_names already
    # hold load_iris().data / .feature_names (see load_data), so the dataset is
    # not reloaded twice on every rerun.
    st.write(pd.DataFrame(X, columns=feature_names).head())
    st.markdown("""
    The Iris dataset contains **150 samples** of iris flowers.
    Each sample includes the following features:
    - 🌸 Sepal Length (cm)
    - 🌸 Sepal Width (cm)
    - 🌸 Petal Length (cm)
    - 🌸 Petal Width (cm)

    The goal of clustering is to find natural patterns among these measurements.
    """)
# ------------- Data Visualization Tab -------------
def _render_and_close(fig):
    """Render a matplotlib figure in Streamlit, then close it.

    Streamlit reruns the whole script on every widget interaction; figures
    that are never closed keep accumulating in matplotlib's figure manager.
    """
    st.pyplot(fig)
    plt.close(fig)


with tab2:
    st.header("Data Visualization")

    # Cluster distribution: sepal measurements colored by cluster label.
    st.subheader("Cluster Distribution")
    fig, ax = plt.subplots()
    sns.scatterplot(
        x=df['sepal length (cm)'],
        y=df['sepal width (cm)'],
        hue=df['Cluster Label'],
        palette='viridis',
        s=100,
        alpha=0.7,
        ax=ax,
    )
    # Label through the Axes object instead of pyplot's implicit current axes.
    ax.set_xlabel('Sepal Length (cm)')
    ax.set_ylabel('Sepal Width (cm)')
    _render_and_close(fig)

    # Correlation heatmap of the measurement columns only. 'Cluster' is an
    # integer id and 'Cluster Label' is text; neither is a measurement.
    st.subheader("Heatmap of Feature Correlation")
    numeric_df = df[list(feature_names)]
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    _render_and_close(fig)

    # Per-cluster box plots of selected features (used instead of a pair plot).
    st.subheader("Box Plot of Features by Cluster")
    for column, title in (
        ('sepal length (cm)', "Sepal Length Distribution Across Clusters"),
        ('petal length (cm)', "Petal Length Distribution Across Clusters"),
    ):
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.boxplot(x='Cluster Label', y=column, data=df, palette='viridis', ax=ax)
        ax.set_title(title)
        _render_and_close(fig)

    # Cluster centers, one row per cluster. K-Means was fit on the
    # standardized features, so values are offsets from the feature mean in
    # standard-deviation units.
    st.subheader("Feature Importance (Based on Cluster Centers)")
    feature_importance = pd.DataFrame(
        kmeans.cluster_centers_,
        columns=feature_names,
        index=[
            f'Cluster {i} ({cluster_labels[i]})' if i in cluster_labels else f'Cluster {i}'
            for i in range(len(kmeans.cluster_centers_))
        ],
    )
    st.dataframe(feature_importance)
    st.markdown("""
    **How to Interpret Positive and Negative Values:**
    - **Positive Value:** The cluster center is positioned **above the mean** for that feature.
      → The cluster tends to have **higher values** for that feature.
    - **Negative Value:** The cluster center is positioned **below the mean** for that feature.
      → The cluster tends to have **lower values** for that feature.
    - **Magnitude:**
      - Higher absolute values = Stronger influence of that feature in defining the cluster.
      - Lower absolute values = Less influence of that feature in cluster formation.
    """)
# ------------- Model Prediction Tab -------------
with tab3:
    st.header("Predict Cluster for Custom Input")

    # Collect one value per feature. Each field starts at that feature's
    # dataset mean (rounded to the 0.1 step) rather than 0.0, which is not a
    # plausible flower measurement.
    input_features = [
        st.number_input(f"Enter {feature}", value=round(float(X[:, idx].mean()), 1), step=0.1)
        for idx, feature in enumerate(feature_names)
    ]

    # Standardize the input with a scaler fitted on the same raw X that
    # load_data() standardized, so it lands in the space kmeans was fit in.
    input_scaled = StandardScaler().fit(X).transform([input_features])

    # st.button is True only on the rerun triggered by the click. Remember the
    # click in session state; otherwise toggling the checkbox below (which
    # itself triggers a rerun) would hide the prediction and the distances.
    if st.button("Predict Cluster"):
        st.session_state["iris_prediction_requested"] = True

    if st.session_state.get("iris_prediction_requested", False):
        cluster = int(kmeans.predict(input_scaled)[0])
        label = cluster_labels[cluster]
        st.success(f"The predicted cluster is: **{label}**")

        # Optional: distance from the (standardized) input to each cluster center.
        if st.checkbox("Show Cluster Distances"):
            st.markdown("""
            **What is Cluster Distance?**
            - Cluster distance represents how close your custom input is to each cluster center.
            - A smaller distance means your input is more similar to that cluster's typical values.
            """)
            distances = kmeans.transform(input_scaled)[0]
            distance_df = pd.DataFrame(
                distances,
                index=[
                    f'Cluster {i} ({cluster_labels[i]})' if i in cluster_labels else f'Cluster {i}'
                    for i in range(len(distances))
                ],
                columns=["Distance"],
            )
            st.write(distance_df)

            # Bar chart of the same distances.
            fig, ax = plt.subplots()
            sns.barplot(
                x=distance_df.index,
                y=distance_df["Distance"],
                palette="viridis",
                ax=ax,
            )
            ax.set_title("Distance to Cluster Centers")
            ax.set_ylabel("Distance")
            st.pyplot(fig)
            # Close after rendering so figures do not pile up across reruns.
            plt.close(fig)
# --------------------- Footer ---------------------
# Horizontal rule followed by a closing note (emoji restored from mis-encoded UTF-8).
st.markdown("---")
st.write("**Awesome 😎**")