# Source: app.py by kheejay88
# Commit: "Update app.py" (6315b8e, verified)
import streamlit as st
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
# Page configuration: browser-tab title and a full-width layout.
st.set_page_config(
    page_title="Unsupervised ML: Mall Customer Segmentation",
    layout="wide",
)
# --------------------- Load and preprocess the dataset ---------------------
@st.cache_data
def load_data():
    """Return the ``train`` split of ``kheejay88/mall_customers`` as a DataFrame.

    The function is wrapped in ``st.cache_data``, so Streamlit reuses the
    stored result instead of calling ``load_dataset`` on every script rerun.
    """
    return pd.DataFrame(load_dataset("kheejay88/mall_customers", split="train"))


df = load_data()
# Preprocess data
@st.cache_data
def preprocess_data(df):
    """Pick the two clustering columns from *df* and standardize them.

    Args:
        df: DataFrame that contains the 'Annual Income (k$)' and
            'Spending Score (1-100)' columns.

    Returns:
        A 4-tuple ``(X, X_scaled, features, scaler)``: the unscaled feature
        frame, the output of ``StandardScaler.fit_transform`` on it, the list
        of the two column names, and the fitted ``StandardScaler``.
    """
    feature_names = ['Annual Income (k$)', 'Spending Score (1-100)']
    raw_features = df[feature_names]
    fitted_scaler = StandardScaler()
    standardized = fitted_scaler.fit_transform(raw_features)
    return raw_features, standardized, feature_names, fitted_scaler


X, X_scaled, features, scaler = preprocess_data(df)
# --------------------- Perform K-Means clustering ---------------------
@st.cache_data
def perform_clustering(X_scaled, n_clusters=5, n_init=10):
    """Fit K-Means on the standardized features and label every row.

    Args:
        X_scaled: Standardized feature matrix to cluster.
        n_clusters: Number of clusters to form.
        n_init: Number of centroid initializations to try. Passed explicitly
            because scikit-learn's default for this argument differs between
            releases (10 in older versions, ``'auto'`` from 1.4), which would
            otherwise make the resulting clusters depend on the installed
            version and emit a FutureWarning on 1.2/1.3.

    Returns:
        A tuple ``(kmeans, clusters)``: the fitted ``KMeans`` estimator and
        the cluster index assigned to each row of ``X_scaled``.
    """
    # random_state is fixed so repeated runs produce the same clusters.
    kmeans = KMeans(n_clusters=n_clusters, n_init=n_init, random_state=42)
    clusters = kmeans.fit_predict(X_scaled)
    return kmeans, clusters


kmeans, clusters = perform_clustering(X_scaled)
# Attach each row's cluster index and a human-readable label to the dataframe.
cluster_labels = {
    cluster_id: f'Cluster {cluster_id}'
    for cluster_id in range(kmeans.n_clusters)
}
df['Cluster'] = clusters
df['Cluster Label'] = df['Cluster'].map(cluster_labels)
# --------------------- Sidebar for user input ---------------------
def _feature_slider(column):
    """Render a sidebar slider for *column* of ``df`` and return its value.

    The slider runs from the column's minimum to its maximum and starts at
    its mean; all three bounds are truncated to ``int``.
    """
    values = df[column]
    return st.sidebar.slider(
        column,
        min_value=int(values.min()),
        max_value=int(values.max()),
        value=int(values.mean()),
    )


st.sidebar.header("User Input Features")
annual_income = _feature_slider('Annual Income (k$)')
spending_score = _feature_slider('Spending Score (1-100)')
# --------------------- Predict cluster for user input ---------------------
def predict_cluster(annual_income, spending_score, kmeans, scaler, feature_names=None):
    """Assign one (income, spending score) pair to a cluster.

    Args:
        annual_income: Annual income value for the single input row.
        spending_score: Spending score value for the single input row.
        kmeans: Fitted clusterer exposing ``predict`` and ``transform``.
        scaler: Fitted scaler exposing ``transform``.
        feature_names: Column names for the one-row input frame, in
            (income, spending score) order. When omitted, the module-level
            ``features`` list is used, which keeps existing callers working.

    Returns:
        A tuple ``(cluster, distances)``: the first element of
        ``kmeans.predict`` for the scaled row, and the first row of
        ``kmeans.transform`` (one distance per cluster center).
    """
    # Only touch the module-level global when the caller gave no column names,
    # so the function can be used without that global being defined.
    columns = features if feature_names is None else list(feature_names)
    # Build a named one-row frame so the scaler sees labelled columns.
    input_data = pd.DataFrame([[annual_income, spending_score]], columns=columns)
    input_scaled = scaler.transform(input_data)
    cluster = kmeans.predict(input_scaled)[0]
    distances = kmeans.transform(input_scaled)[0]
    return cluster, distances
# Prediction: classify the values currently selected in the sidebar.
user_cluster, distances = predict_cluster(annual_income, spending_score, kmeans, scaler)

# --------------------- Main panel ---------------------
# Emoji are written as real Unicode characters; the previous text was UTF-8
# that had been decoded with the wrong codec.
st.title("🛍️ Mall Customer Segmentation App")

# --------------------- Tabs ---------------------
tab1, tab2, tab3 = st.tabs(["🏠 About", "📊 Data Visualization", "🔎 Predict Cluster"])
# --------------------- About Tab ---------------------
# Static description of the app plus a preview of the data. The cluster count
# and sample count are read from the fitted model and the dataframe so the
# text cannot drift from what the app actually computed.
with tab1:
    st.header("About This App")
    st.markdown(f"""
## Overview
This app uses **K-Means clustering** to segment mall customers based on their annual income and spending score.
## How It Works
1. **Data Preprocessing**:
- Data is scaled using `StandardScaler` so both features have zero mean and unit variance.
2. **Clustering**:
- The K-Means algorithm groups customers into {kmeans.n_clusters} clusters.
3. **Prediction**:
- Users can provide input values and the app will predict which cluster they belong to.
## Dataset Information
""")
    # First rows of the dataframe, including the added cluster columns.
    st.dataframe(df.head())
    st.markdown(f"""
The dataset contains **{len(df)} samples** of customer data with the following features:
- **Annual Income (k$)**
- **Spending Score (1-100)**
- **Customer ID, Gender, Age** (used for reference)
""")
# --------------------- Data Visualization Tab ---------------------
# Scatter plot of the clustered customers and a table of the cluster centers.
with tab2:
    st.header("Data Visualization")

    # Cluster distribution plot
    st.subheader("Cluster Distribution")
    fig, ax = plt.subplots()
    sns.scatterplot(
        x=df['Annual Income (k$)'],
        y=df['Spending Score (1-100)'],
        hue=df['Cluster Label'],
        palette='viridis',
        s=100,
        alpha=0.7,
        ax=ax,
    )
    # Label through the Axes object rather than pyplot's global "current
    # axes", so the text always lands on this figure.
    ax.set_xlabel('Annual Income (k$)')
    ax.set_ylabel('Spending Score (1-100)')
    ax.set_title('Customer Segments')
    st.pyplot(fig)
    # pyplot keeps every figure it creates registered until it is closed;
    # without this each script rerun would leave one more figure in memory.
    plt.close(fig)

    # Feature importance (Cluster centers)
    st.subheader("Cluster Centers")
    # Centers are stored in scaled space; map them back to original units.
    cluster_centers = pd.DataFrame(
        scaler.inverse_transform(kmeans.cluster_centers_),
        columns=features,
    )
    cluster_centers['Cluster'] = list(cluster_labels.values())
    st.dataframe(cluster_centers)
# --------------------- Predict Cluster Tab ---------------------
# Shows the sidebar input, the cluster predicted for it and, on request, the
# distance from the input to every cluster center.
with tab3:
    st.header("Predict Cluster for Custom Input")

    # Display user input
    st.subheader("User Input:")
    st.write(f"**Annual Income (k$):** {annual_income}")
    st.write(f"**Spending Score (1-100):** {spending_score}")

    # Display predicted cluster
    st.subheader("Predicted Cluster:")
    st.write(f"Your input corresponds to **{cluster_labels[user_cluster]}**.")

    # Show cluster center distances with explanation
    if st.checkbox("Show Cluster Distances"):
        st.write("**Distance to Each Cluster:**")

        # One label per distance, built once and shared by table and chart.
        distance_labels = [f'Cluster {i}' for i in range(len(distances))]

        # Display distances in a table
        distance_df = pd.DataFrame(distances, index=distance_labels, columns=["Distance"])
        st.dataframe(distance_df)

        # 📊 Distance bar plot
        st.subheader("Distance to Each Cluster (Graph)")
        fig, ax = plt.subplots()
        # NOTE(review): newer seaborn releases warn when `palette` is given
        # without `hue`; left as-is because the fix depends on the installed
        # seaborn version.
        sns.barplot(
            x=distance_labels,
            y=distances,
            palette='viridis',
            ax=ax,
        )
        # Label through the Axes object rather than pyplot's global state.
        ax.set_ylabel("Distance")
        ax.set_title("Distance to Each Cluster")
        st.pyplot(fig)
        # Close the figure so reruns do not accumulate open figures.
        plt.close(fig)

        st.markdown("""
**How to Interpret:**
- A **lower distance** means the input is closer to that cluster's center.
- The predicted cluster will have the smallest distance.
""")
# --------------------- Footer ---------------------
# A horizontal rule followed by the author credit, shown below the tabs.
FOOTER_RULE = "---"
FOOTER_CREDIT = "**By: kheejay**"
st.markdown(FOOTER_RULE)
st.write(FOOTER_CREDIT)