"""
Customer Segmentation Streamlit App
==================================
A comprehensive web application for customer segmentation analysis using
K-Means and DBSCAN clustering algorithms.
"""
import streamlit as st
import pandas as pd
import numpy as np
import sys
import os
# Add src to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.data_loader import DataLoader
from src.clustering import ClusteringAnalyzer
from src.visualizations import Visualizer
# Default Plotly template for figures created after this point.
import plotly.io as pio

pio.templates.default = "plotly_dark"

# Page-level Streamlit settings: title, icon, wide layout, sidebar expanded.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="๐๏ธ",
    page_title="Customer Segmentation Analysis",
)

# Slot for the dark-mode compatible CSS; the markup injected here is
# currently just a newline.
st.markdown("\n", unsafe_allow_html=True)
def initialize_session_state():
    """Fill ``st.session_state`` with every entry the app expects.

    Keys that are already present are left untouched; only missing keys
    receive a freshly built default.
    """
    # Factories rather than instances, so a default object is only
    # constructed when its key is actually absent.
    default_factories = (
        ('data_loader', DataLoader),
        ('clustering_analyzer', ClusteringAnalyzer),
        ('visualizer', Visualizer),
        ('data_loaded', lambda: False),
        ('data_preprocessed', lambda: False),
        ('clustering_done', lambda: {'kmeans': False, 'dbscan': False}),
    )
    for state_key, make_default in default_factories:
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = make_default()
def _reset_analysis_state():
    """Discard preprocessing and clustering state after the dataset changes."""
    st.session_state.data_preprocessed = False
    st.session_state.clustering_done = {'kmeans': False, 'dbscan': False}
    # A fresh analyzer drops any labels/models computed on the old dataset.
    st.session_state.clustering_analyzer = ClusteringAnalyzer()


def main():
    """Main application function.

    Renders the header and the tab bar, then the sidebar data controls
    (automatic first-run load, CSV upload, reload button, debug reset) and
    finally the content of each tab.
    """
    initialize_session_state()

    # Main header
    st.markdown("\n๐๏ธ Customer Segmentation Analysis\n", unsafe_allow_html=True)
    st.markdown("---")

    # Tab navigation
    tabs = st.tabs([
        "๐ Home", "๐ Data Overview", "๐ Data Exploration", "โ๏ธ Preprocessing",
        "๐ฏ K-Means", "๐ DBSCAN", "๐ Comparison", "๐ Insights"
    ])

    # Data loading section in sidebar
    st.sidebar.markdown("---")
    st.sidebar.subheader("๐ Data Management")

    loader = st.session_state.data_loader

    # Auto-load dataset on first run. The flag mirrors whether the loader
    # really holds data, so the pages never see data_loaded=True with no data.
    if not st.session_state.data_loaded:
        loader.load_data()
        st.session_state.data_loaded = loader.data is not None

    # Show current dataset status
    if st.session_state.data_loaded and loader.data is not None:
        data_info = loader.get_data_info()
        st.sidebar.success("๐ Dataset Loaded")
        st.sidebar.info(f"**Rows:** {data_info['shape'][0]}\n**Columns:** {data_info['shape'][1]}")
        # Show basic info about the dataset
        if 'Annual Income (k$)' in loader.data.columns:
            st.sidebar.write("**Dataset Type:** Mall Customers")

    # File upload option
    st.sidebar.markdown("### ๐ Upload Different Dataset")
    uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type=['csv'])
    if uploaded_file is None:
        # Forget the last ingested file so the same file can be uploaded again.
        if 'uploaded_file_signature' in st.session_state:
            del st.session_state['uploaded_file_signature']
    else:
        # The uploader keeps returning the file on every rerun. Ingest it only
        # once per distinct file; otherwise each run would reset the analysis
        # state and trigger another rerun, looping forever.
        signature = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get('uploaded_file_signature') != signature:
            try:
                data = pd.read_csv(uploaded_file)
            except Exception as e:
                # UI boundary: report any parsing problem instead of crashing.
                st.sidebar.error(f"Error loading file: {e}")
            else:
                # Kept outside the try block so the broad handler only guards
                # CSV parsing.
                loader.data = data
                st.session_state.data_loaded = True
                st.session_state.uploaded_file_signature = signature
                _reset_analysis_state()
                st.sidebar.success("โ\nNew file uploaded!")
                st.rerun()

    # Reload default dataset button
    if st.sidebar.button("๐ Reload Default Dataset"):
        loader.load_data()
        st.session_state.data_loaded = loader.data is not None
        _reset_analysis_state()
        st.rerun()

    # Debug: Clear session state button (remove this after fixing)
    if st.sidebar.button("๐งช Clear Session (Debug)"):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()

    # Tab content: one renderer per tab, in the same order as the labels.
    tab_renderers = (
        show_home_page,
        show_data_overview,
        show_data_exploration,
        show_preprocessing,
        show_kmeans_clustering,
        show_dbscan_clustering,
        show_results_comparison,
        show_business_insights,
    )
    for tab, render in zip(tabs, tab_renderers):
        with tab:
            render()
def show_home_page():
    """Render the landing tab: overview, feature list, usage steps, dataset notes."""
    st.markdown('', unsafe_allow_html=True)

    # Project overview sits in the wide middle column of a 1:2:1 layout.
    _, centre, _ = st.columns([1, 2, 1])
    overview_text = "\n".join([
        "",
        "๐ฏ Project Overview",
        "This application provides a comprehensive customer segmentation analysis using machine learning clustering algorithms.",
        "",
    ])
    centre.markdown(overview_text, unsafe_allow_html=True)

    # Feature overview: one markdown blurb per column.
    st.markdown("### ๐ Features")
    feature_blurbs = (
        [
            "**๐ Data Analysis**",
            "- Interactive data exploration",
            "- Statistical summaries",
            "- Correlation analysis",
            "- Missing value detection",
        ],
        [
            "**๐ฏ Clustering Algorithms**",
            "- K-Means clustering",
            "- DBSCAN clustering",
            "- Optimal cluster determination",
            "- Performance metrics",
        ],
        [
            "**๐ Visualizations**",
            "- 2D cluster plots",
            "- Distribution analysis",
            "- Comparative visualizations",
            "- Interactive charts",
        ],
    )
    for column, blurb_lines in zip(st.columns(3), feature_blurbs):
        # Leading/trailing empty items reproduce the surrounding newlines.
        column.markdown("\n".join(["", *blurb_lines, ""]))

    # Getting started
    st.markdown("### ๐ Getting Started")
    steps = [
        "1. **๐ Data Overview**: Check your dataset information and statistics (automatically loaded from `data/Mall_Customers.csv`)",
        "2. **๐ Data Exploration**: Explore distributions, correlations, and relationships",
        "3. **โ๏ธ Preprocessing**: Select features and scale your data for clustering",
        "4. **๐ฏ K-Means**: Apply K-Means clustering with optimal cluster determination",
        "5. **๐ DBSCAN**: Try density-based clustering for comparison",
        "6. **๐ Comparison**: Compare results from both algorithms",
        "7. **๐ Insights**: Get business recommendations for each customer segment",
    ]
    st.markdown("\n".join(["", *steps, ""]))

    # Quick start note
    quick_start = (
        "\n"
        "๐ก **Quick Start**: Your dataset is automatically loaded from the `data/` folder.\n"
        "Just click on the tabs above to start exploring and clustering your customer data!\n"
    )
    st.info(quick_start)

    # Sample data info
    st.markdown("### ๐ Sample Dataset")
    dataset_notes = (
        "\n"
        "The sample dataset simulates mall customer data with the following features:\n"
        "- **CustomerID**: Unique identifier\n"
        "- **Gender**: Customer gender (Male/Female)\n"
        "- **Age**: Customer age (18-70 years)\n"
        "- **Annual Income (k$)**: Annual income in thousands\n"
        "- **Spending Score (1-100)**: Mall-assigned spending score\n"
    )
    st.info(dataset_notes)
def show_data_overview():
    """Display data overview page.

    Shows headline metrics, a 10-row preview, per-column dtypes and missing
    value counts, and ``describe()`` statistics for the loaded dataset.
    Emits a warning and returns early when no dataset is available.
    """
    st.markdown('', unsafe_allow_html=True)

    data = st.session_state.data_loader.data
    # The flag alone is not enough: it may be set although the loader holds
    # no data, so both conditions are checked (as the sidebar status does).
    if not st.session_state.data_loaded or data is None:
        st.warning("โ ๏ธ Please load data first using the sidebar.")
        return

    data_info = st.session_state.data_loader.get_data_info()
    row_count = data_info['shape'][0]

    # Basic information
    missing_values = sum(data_info['missing_values'].values())
    # Count every numeric dtype (not only int64/float64), matching the
    # selection rule used on the preprocessing page.
    numeric_cols = len(data.select_dtypes(include=[np.number]).columns)
    headline_metrics = (
        ("Total Customers", row_count),
        ("Features", data_info['shape'][1]),
        ("Missing Values", missing_values),
        ("Numeric Features", numeric_cols),
    )
    for column, (label, value) in zip(st.columns(4), headline_metrics):
        column.metric(label, value)

    # Data preview
    st.subheader("๐ Data Preview")
    st.dataframe(data.head(10), use_container_width=True)

    # Data types and missing values
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("๐ง Data Types")
        dtypes_df = pd.DataFrame(list(data_info['dtypes'].items()), columns=['Column', 'Data Type'])
        st.dataframe(dtypes_df, use_container_width=True)
    with col2:
        st.subheader("โ Missing Values")
        missing_df = pd.DataFrame(list(data_info['missing_values'].items()), columns=['Column', 'Missing Count'])
        missing_df['Missing %'] = (missing_df['Missing Count'] / row_count * 100).round(2)
        st.dataframe(missing_df, use_container_width=True)

    # Statistical summary
    st.subheader("๐ Statistical Summary")
    st.dataframe(data.describe(), use_container_width=True)
def show_data_exploration():
    """Display data exploration page.

    Delegates all plotting to the session's visualizer. Emits a warning and
    returns early when no dataset is available.
    """
    st.markdown('', unsafe_allow_html=True)

    data = st.session_state.data_loader.data
    # Guard against the flag being set while the loader holds no data.
    if not st.session_state.data_loaded or data is None:
        st.warning("โ ๏ธ Please load data first using the sidebar.")
        return

    # Generate exploration visualizations
    st.session_state.visualizer.plot_data_exploration(data)
def show_preprocessing():
    """Display preprocessing page.

    Lets the user pick at least two numeric features, runs the loader's
    ``preprocess_data`` on them when the button is pressed, and shows the
    original and scaled data side by side. A successful run marks the data
    as preprocessed and invalidates earlier clustering results.
    """
    st.markdown('', unsafe_allow_html=True)

    data = st.session_state.data_loader.data
    # Guard against the flag being set while the loader holds no data.
    if not st.session_state.data_loaded or data is None:
        st.warning("โ ๏ธ Please load data first using the sidebar.")
        return

    # Feature selection
    st.subheader("๐ฏ Feature Selection")
    numeric_columns = data.select_dtypes(include=[np.number]).columns.tolist()
    if 'CustomerID' in numeric_columns:
        # An identifier carries no information for clustering.
        numeric_columns.remove('CustomerID')

    preferred_features = ['Annual Income (k$)', 'Spending Score (1-100)']
    if all(col in numeric_columns for col in preferred_features):
        default_features = preferred_features
    else:
        default_features = numeric_columns[:2]
    selected_features = st.multiselect(
        "Select features for clustering:",
        numeric_columns,
        default=default_features
    )
    if len(selected_features) < 2:
        st.error("โ ๏ธ Please select at least 2 features for clustering.")
        return

    # Preprocessing options
    # NOTE(review): these two selections are collected but not forwarded to
    # preprocess_data below; confirm whether that method can accept them.
    st.subheader("๐ง Preprocessing Options")
    col1, col2 = st.columns(2)
    with col1:
        handle_missing = st.selectbox("Handle missing values:", ["Fill with mean", "Drop rows", "No action"])
    with col2:
        scaling_method = st.selectbox("Scaling method:", ["StandardScaler", "MinMaxScaler", "No scaling"])

    # Apply preprocessing
    if st.button("๐ Apply Preprocessing"):
        scaled_data = st.session_state.data_loader.preprocess_data(selected_features)
        if scaled_data is not None:
            st.session_state.data_preprocessed = True
            # Labels computed on the previous feature set no longer match the
            # new scaled data, so earlier clustering results are invalidated.
            st.session_state.clustering_done = {'kmeans': False, 'dbscan': False}

            # Show preprocessing results
            st.success("โ\nData preprocessing completed!")
            scaled_df = pd.DataFrame(scaled_data, columns=selected_features)
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("๐ Original Data")
                st.dataframe(data[selected_features].head(), use_container_width=True)
            with col2:
                st.subheader("๐ Scaled Data")
                st.dataframe(scaled_df.head(), use_container_width=True)

            # Feature statistics
            st.subheader("๐ Feature Statistics")
            col1, col2 = st.columns(2)
            with col1:
                st.write("**Original Data Statistics:**")
                st.dataframe(data[selected_features].describe(), use_container_width=True)
            with col2:
                st.write("**Scaled Data Statistics:**")
                st.dataframe(scaled_df.describe(), use_container_width=True)
def show_kmeans_clustering():
    """Display K-Means clustering page.

    Offers an optimal-cluster search, runs K-Means with the chosen number of
    clusters, shows its quality metrics, and plots the clusters and their
    analysis once a K-Means run is recorded in the session state.
    Requires preprocessed data; otherwise a warning is shown.
    """
    st.markdown('', unsafe_allow_html=True)
    if not st.session_state.data_preprocessed:
        st.warning("โ ๏ธ Please preprocess data first.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    visualizer = st.session_state.visualizer

    # Optimal cluster determination
    st.subheader("๐ Optimal Cluster Determination")
    col1, col2 = st.columns([1, 1])
    with col1:
        max_clusters = st.slider("Maximum clusters to test:", 2, 15, 10)
    with col2:
        if st.button("๐ Find Optimal Clusters"):
            with st.spinner("Finding optimal number of clusters..."):
                optimization_results = clustering_analyzer.find_optimal_clusters(
                    data_loader.scaled_data, max_clusters
                )
                if optimization_results:
                    visualizer.plot_optimization_results(optimization_results)

    # K-Means clustering
    st.subheader("๐ฏ K-Means Clustering")
    # The search above may test up to 15 clusters while this slider stops at
    # 10; the suggested default is clamped into the slider range because a
    # default outside [min, max] is rejected by st.slider.
    min_k, max_k, fallback_k = 2, 10, 5
    suggested_k = clustering_analyzer.optimal_clusters or fallback_k
    default_k = min(max(int(suggested_k), min_k), max_k)

    col1, col2 = st.columns([1, 1])
    with col1:
        n_clusters = st.slider("Number of clusters:", min_k, max_k, default_k)
    with col2:
        if st.button("๐ Apply K-Means"):
            # Clear any existing clustering results first to avoid column naming issues
            clustering_analyzer.cluster_labels = {}
            st.session_state.clustering_done = {'kmeans': False, 'dbscan': False}
            # Clear any cached data
            if 'cluster_analysis_cache' in st.session_state:
                del st.session_state['cluster_analysis_cache']

            with st.spinner("๐ Applying K-Means clustering..."):
                kmeans_results = clustering_analyzer.apply_kmeans(data_loader.scaled_data, n_clusters)
                if kmeans_results:
                    st.session_state.clustering_done['kmeans'] = True
                    # Display metrics
                    col1, col2, col3 = st.columns(3)
                    with col1:
                        st.metric("Silhouette Score", f"{kmeans_results['silhouette_score']:.3f}")
                    with col2:
                        st.metric("Calinski-Harabasz Score", f"{kmeans_results['calinski_score']:.1f}")
                    with col3:
                        st.metric("Inertia", f"{kmeans_results['inertia']:.1f}")

    # Visualizations: driven by the session flag so they are also rendered on
    # reruns that were not triggered by the button.
    if st.session_state.clustering_done['kmeans']:
        feature_data = data_loader.get_feature_data()
        kmeans_labels = clustering_analyzer.cluster_labels['kmeans']
        visualizer.plot_clusters(
            feature_data,
            kmeans_labels,
            'K-Means',
            data_loader.scaler,
            clustering_analyzer.kmeans_model.cluster_centers_
        )
        # Cluster analysis
        analysis_results = clustering_analyzer.analyze_clusters(feature_data, 'kmeans')
        if analysis_results:
            visualizer.plot_cluster_analysis(analysis_results, 'K-Means')
def show_dbscan_clustering():
    """Render the DBSCAN tab: parameter sliders, run button, metrics and plots."""
    st.markdown('', unsafe_allow_html=True)

    state = st.session_state
    if not state.data_preprocessed:
        st.warning("โ ๏ธ Please preprocess data first.")
        return

    loader = state.data_loader
    analyzer = state.clustering_analyzer
    plotter = state.visualizer

    # DBSCAN parameters, one slider per column.
    st.subheader("โ๏ธ DBSCAN Parameters")
    eps_col, samples_col = st.columns(2)
    eps = eps_col.slider("Epsilon (neighborhood distance):", 0.1, 2.0, 0.5, 0.1)
    min_samples = samples_col.slider("Minimum samples per cluster:", 2, 20, 5)

    # Parameter guidance
    guidance = (
        "\n"
        "**Parameter Guidance:**\n"
        "- **Epsilon**: Maximum distance between points in the same cluster. Smaller values create more clusters.\n"
        "- **Min Samples**: Minimum number of points required to form a cluster. Higher values create fewer, denser clusters.\n"
    )
    st.info(guidance)

    # Run DBSCAN on demand and report its headline metrics.
    if st.button("๐ Apply DBSCAN"):
        outcome = analyzer.apply_dbscan(loader.scaled_data, eps, min_samples)
        if outcome:
            state.clustering_done['dbscan'] = True
            clusters_col, noise_col, score_col = st.columns(3)
            clusters_col.metric("Number of Clusters", outcome['n_clusters'])
            noise_col.metric("Noise Points", outcome['n_noise'])
            # The score key is only shown when the result dict provides it.
            score_text = (
                f"{outcome['silhouette_score']:.3f}"
                if 'silhouette_score' in outcome
                else "N/A"
            )
            score_col.metric("Silhouette Score", score_text)

    # Nothing to plot until a DBSCAN run has been recorded in the session.
    if not state.clustering_done['dbscan']:
        return

    features = loader.get_feature_data()
    plotter.plot_clusters(features, analyzer.cluster_labels['dbscan'], 'DBSCAN')

    # Cluster analysis
    summary = analyzer.analyze_clusters(features, 'dbscan')
    if summary:
        plotter.plot_cluster_analysis(summary, 'DBSCAN')
def _format_silhouette(features, labels, exclude_noise=False):
    """Return the silhouette score as a 3-decimal string, or "N/A".

    With ``exclude_noise`` set, samples labelled -1 are dropped first.
    "N/A" is returned when scikit-learn is unavailable or when the score is
    undefined for the given labels.
    """
    try:
        from sklearn.metrics import silhouette_score
    except ImportError:
        return "N/A"

    features = np.asarray(features)
    labels = np.asarray(labels)
    if exclude_noise:
        keep = labels != -1
        features, labels = features[keep], labels[keep]

    # The score needs at least 2 distinct labels and fewer labels than samples.
    distinct_labels = len(set(labels.tolist()))
    if not 2 <= distinct_labels <= len(labels) - 1:
        return "N/A"
    try:
        return f"{silhouette_score(features, labels):.3f}"
    except ValueError:
        return "N/A"


def show_results_comparison():
    """Display results comparison page.

    Plots K-Means and DBSCAN labels side by side and tabulates cluster count,
    silhouette score, noise points and largest cluster size per algorithm.
    Requires both algorithms to have been run; otherwise a warning is shown.
    """
    st.markdown('', unsafe_allow_html=True)
    done = st.session_state.clustering_done
    if not (done['kmeans'] and done['dbscan']):
        st.warning("โ ๏ธ Please complete both K-Means and DBSCAN clustering first.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    visualizer = st.session_state.visualizer

    feature_data = data_loader.get_feature_data()
    kmeans_labels = clustering_analyzer.cluster_labels['kmeans']
    dbscan_labels = clustering_analyzer.cluster_labels['dbscan']

    # Comparison visualization
    visualizer.plot_comparison(feature_data, kmeans_labels, dbscan_labels)

    # Performance comparison
    st.subheader("๐ Performance Metrics Comparison")

    # NOTE(review): the results of these calls are not used on this page; the
    # calls are retained in case analyze_clusters has side effects. Confirm
    # and drop them if it is a pure computation.
    clustering_analyzer.analyze_clusters(feature_data, 'kmeans')
    clustering_analyzer.analyze_clusters(feature_data, 'dbscan')

    kmeans_sizes = pd.Series(np.asarray(kmeans_labels)).value_counts()
    dbscan_array = np.asarray(dbscan_labels)
    noise_points = int((dbscan_array == -1).sum())
    # Label -1 marks noise and is not counted as a cluster.
    dbscan_sizes = pd.Series(dbscan_array).value_counts().drop(-1, errors='ignore')

    # Each score is computed independently so a failure for one algorithm
    # cannot add an extra cell to the other column. All cells are strings to
    # keep every column a single type.
    comparison_df = pd.DataFrame({
        'Metric': ['Number of Clusters', 'Silhouette Score', 'Noise Points', 'Largest Cluster Size'],
        'K-Means': [
            str(len(kmeans_sizes)),
            _format_silhouette(data_loader.scaled_data, kmeans_labels),
            "0",
            str(kmeans_sizes.max()),
        ],
        'DBSCAN': [
            str(len(dbscan_sizes)),
            _format_silhouette(data_loader.scaled_data, dbscan_labels, exclude_noise=True),
            str(noise_points),
            str(dbscan_sizes.max()) if len(dbscan_sizes) > 0 else "0",
        ],
    })
    st.dataframe(comparison_df, use_container_width=True)
def _segment_recommendations(avg_income, avg_spending):
    """Return the recommendation bullet lines for a segment's income/spending."""
    high_income, low_income = avg_income > 70, avg_income < 40
    high_spending, low_spending = avg_spending > 70, avg_spending < 40

    if high_income and high_spending:
        return (
            "- Focus on premium products and exclusive services",
            "- Implement VIP loyalty programs",
            "- Offer personalized shopping experiences",
        )
    if high_income and low_spending:
        return (
            "- Develop targeted upselling strategies",
            "- Showcase value propositions",
            "- Create incentive programs to increase spending",
        )
    if low_income and high_spending:
        return (
            "- Offer value-based products and promotions",
            "- Focus on customer retention programs",
            "- Provide flexible payment options",
        )
    if low_income and low_spending:
        return (
            "- Implement engagement and retention strategies",
            "- Offer budget-friendly options",
            "- Focus on building brand loyalty",
        )
    # Everything in between gets the generic advice.
    return (
        "- Balanced marketing approach",
        "- Personalized offers based on preferences",
        "- Regular engagement campaigns",
    )


def show_business_insights():
    """Display business insights page.

    Shows one expander per K-Means cluster profile (size, demographics,
    financial figures, recommendations), a general strategy section, and a
    CSV download of the feature data with the K-Means label attached.
    Requires a completed K-Means run; otherwise a warning is shown.
    """
    st.markdown('', unsafe_allow_html=True)
    if not st.session_state.clustering_done['kmeans']:
        st.warning("โ ๏ธ Please complete K-Means clustering first to generate insights.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    feature_data = data_loader.get_feature_data()

    def spread(profile, std_key, prefix="", suffix=""):
        """Format the ' ± std' part, or '' when the profile has no such std."""
        std_value = profile.get(std_key)
        if std_value is None:
            return ""
        return f" ± {prefix}{std_value:.1f}{suffix}"

    # Generate customer profiles
    profiles = clustering_analyzer.get_cluster_profiles(feature_data, 'kmeans')
    if profiles:
        st.subheader("๐ฅ Customer Segment Profiles")
        for profile in profiles:
            with st.expander(f"๐ท๏ธ Cluster {profile['cluster']} - {profile.get('type', 'Unknown Type')}"):
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown("**๐ Segment Overview**")
                    st.write(f"- **Size**: {profile['size']} customers ({profile['percentage']:.1f}%)")
                    if 'description' in profile:
                        st.write(f"- **Profile**: {profile['description']}")
                    if 'avg_age' in profile:
                        st.write(
                            f"- **Average Age**: {profile['avg_age']:.1f}"
                            f"{spread(profile, 'age_std')} years"
                        )
                    if 'gender_dist' in profile:
                        st.write(f"- **Gender Distribution**: {profile['gender_dist']}")
                with col2:
                    st.markdown("**๐ฐ Financial Profile**")
                    if 'avg_income' in profile:
                        st.write(
                            f"- **Average Income**: ${profile['avg_income']:.1f}k"
                            f"{spread(profile, 'income_std', prefix='$', suffix='k')}"
                        )
                    if 'avg_spending' in profile:
                        st.write(
                            f"- **Average Spending Score**: {profile['avg_spending']:.1f}"
                            f"{spread(profile, 'spending_std')}"
                        )

                # Business recommendations
                st.markdown("**๐ Recommendations**")
                if 'avg_income' in profile and 'avg_spending' in profile:
                    for line in _segment_recommendations(profile['avg_income'], profile['avg_spending']):
                        st.write(line)

    # Overall business strategy
    st.subheader("๐ฏ Overall Business Strategy")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown(
            "\n"
            "**๐ฏ Marketing Strategies**\n"
            "- **Segment-specific campaigns**: Tailor marketing messages to each cluster\n"
            "- **Product positioning**: Align products with cluster preferences\n"
            "- **Channel optimization**: Use preferred communication channels per segment\n"
            "- **Pricing strategies**: Implement dynamic pricing based on segment characteristics\n"
        )
    with col2:
        st.markdown(
            "\n"
            "**๐ก Growth Opportunities**\n"
            "- **Cross-selling**: Identify products popular in high-spending segments\n"
            "- **Retention programs**: Focus on segments with declining engagement\n"
            "- **New product development**: Create offerings for underserved segments\n"
            "- **Customer lifetime value**: Invest more in high-value segments\n"
        )

    # Download results
    st.subheader("๐พ Download Results")
    # Work on a copy so the label column is not added to the loader's data.
    result_data = feature_data.copy()
    result_data['KMeans_Cluster'] = clustering_analyzer.cluster_labels['kmeans']
    csv = result_data.to_csv(index=False)
    st.download_button(
        label="๐ฅ Download Customer Segments (CSV)",
        data=csv,
        file_name="customer_segments_results.csv",
        mime="text/csv"
    )
# Start the app only when this file is run as the entry script, not on import.
if __name__ == "__main__":
    main()