"""
Customer Segmentation Streamlit App
==================================

A comprehensive web application for customer segmentation analysis using
K-Means and DBSCAN clustering algorithms.
"""

import os
import sys

import numpy as np
import pandas as pd
import plotly.io as pio
import streamlit as st

# Add the parent of this file's directory to the import path so the project's
# ``src`` package can be imported below.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from src.data_loader import DataLoader  # noqa: E402  (needs the sys.path tweak)
from src.clustering import ClusteringAnalyzer  # noqa: E402
from src.visualizations import Visualizer  # noqa: E402

# Page configuration -- must be the first Streamlit command of the script.
st.set_page_config(
    page_title="Customer Segmentation Analysis",
    page_icon="🛍️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Render every Plotly figure with the dark template by default.
pio.templates.default = "plotly_dark"

# Modern Dark Mode Compatible CSS (the stylesheet body is currently empty).
st.markdown(""" """, unsafe_allow_html=True)

# Default clustering features of the bundled Mall Customers dataset.
DEFAULT_FEATURES = ['Annual Income (k$)', 'Spending Score (1-100)']

# Range of the "Number of clusters" slider on the K-Means tab.
MIN_CLUSTERS = 2
MAX_CLUSTERS = 10

# Income (k$) / spending-score thresholds used to choose segment recommendations.
HIGH_THRESHOLD = 70
LOW_THRESHOLD = 40


def _section_header(title, level=2):
    """Render a page or section title as a Markdown heading.

    Args:
        title: Heading text.
        level: Markdown heading level (1 is the largest).
    """
    st.markdown(f"{'#' * level} {title}")


def _empty_clustering_flags():
    """Return a fresh mapping recording which algorithms have been run."""
    return {'kmeans': False, 'dbscan': False}


def initialize_session_state():
    """Initialize session state variables that are not set yet."""
    # Factories rather than instances, so objects are only built when missing.
    defaults = {
        'data_loader': DataLoader,
        'clustering_analyzer': ClusteringAnalyzer,
        'visualizer': Visualizer,
        'data_loaded': lambda: False,
        'data_preprocessed': lambda: False,
        'clustering_done': _empty_clustering_flags,
    }
    for key, factory in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = factory()


def _data_available():
    """Return True when a dataset is flagged as loaded and holds a DataFrame."""
    return (
        st.session_state.data_loaded
        and st.session_state.data_loader.data is not None
    )


def _reset_analysis_state():
    """Invalidate preprocessing and clustering results.

    Called whenever the dataset or feature set changes: cluster labels, fitted
    models and the analyzer's cached state all refer to the previous data.
    """
    st.session_state.data_preprocessed = False
    st.session_state.clustering_done = _empty_clustering_flags()
    st.session_state.clustering_analyzer = ClusteringAnalyzer()


def main():
    """Main application function: sidebar data management plus tabbed pages."""
    initialize_session_state()

    # Main header
    _section_header("🛍️ Customer Segmentation Analysis", level=1)
    st.markdown("---")

    # Tab navigation: (label, page renderer) pairs in display order.
    pages = [
        ("🏠 Home", show_home_page),
        ("📊 Data Overview", show_data_overview),
        ("🔍 Data Exploration", show_data_exploration),
        ("⚙️ Preprocessing", show_preprocessing),
        ("🎯 K-Means", show_kmeans_clustering),
        ("🌟 DBSCAN", show_dbscan_clustering),
        ("📈 Comparison", show_results_comparison),
        ("📋 Insights", show_business_insights),
    ]
    tabs = st.tabs([label for label, _ in pages])

    # The sidebar is processed before the tab bodies so any data change made
    # there is already visible to every page during this run.
    _render_data_sidebar()

    for tab, (_, render_page) in zip(tabs, pages):
        with tab:
            render_page()


def _render_data_sidebar():
    """Render the sidebar "Data Management" panel and handle its actions."""
    data_loader = st.session_state.data_loader

    st.sidebar.markdown("---")
    st.sidebar.subheader("📂 Data Management")

    # Auto-load dataset on first run. Only flag the data as loaded when the
    # loader actually produced a DataFrame, so pages never receive None.
    if not st.session_state.data_loaded:
        data_loader.load_data()
        st.session_state.data_loaded = data_loader.data is not None

    # Show current dataset status
    if _data_available():
        data_info = data_loader.get_data_info()
        st.sidebar.success("📊 Dataset Loaded")
        # "  \n" is a Markdown hard line break; a bare "\n" renders as a space.
        st.sidebar.info(
            f"**Rows:** {data_info['shape'][0]}  \n"
            f"**Columns:** {data_info['shape'][1]}"
        )
        if 'Annual Income (k$)' in data_loader.data.columns:
            st.sidebar.write("**Dataset Type:** Mall Customers")
    else:
        st.sidebar.warning("⚠️ No dataset loaded.")

    # File upload option
    st.sidebar.markdown("### 📁 Upload Different Dataset")
    uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type=['csv'])
    if uploaded_file is not None:
        # The uploader returns the same file on every rerun. Ingest each file
        # only once; otherwise the st.rerun() below restarts the script forever
        # and wipes preprocessing/clustering on every pass.
        upload_key = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get('uploaded_file_key') != upload_key:
            try:
                data = pd.read_csv(uploaded_file)
            except Exception as e:  # report any parse failure to the user
                st.sidebar.error(f"Error loading file: {e}")
            else:
                data_loader.data = data
                st.session_state.data_loaded = True
                st.session_state.uploaded_file_key = upload_key
                _reset_analysis_state()
                # Shown after the rerun (a message emitted now would be lost).
                st.session_state.upload_notice = True
                st.rerun()

    if st.session_state.pop('upload_notice', False):
        st.sidebar.success("✅ New file uploaded!")

    # Reload default dataset button. The remembered upload key is kept, so a
    # file still sitting in the uploader is not silently re-ingested.
    if st.sidebar.button("🔄 Reload Default Dataset"):
        data_loader.load_data()
        st.session_state.data_loaded = data_loader.data is not None
        _reset_analysis_state()
        st.rerun()

    # Debug: Clear session state button (remove this after fixing)
    if st.sidebar.button("🧪 Clear Session (Debug)"):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()


def show_home_page():
    """Display the home page."""
    _section_header("Welcome to Customer Segmentation Analysis")

    _, middle, _ = st.columns([1, 2, 1])
    with middle:
        st.markdown(
            "### 🎯 Project Overview\n\n"
            "This application provides a comprehensive customer segmentation "
            "analysis using machine learning clustering algorithms."
        )

    # Feature overview: (card title, bullet points) per column.
    st.markdown("### 🚀 Features")
    feature_cards = (
        ("📊 Data Analysis", (
            "Interactive data exploration",
            "Statistical summaries",
            "Correlation analysis",
            "Missing value detection",
        )),
        ("🎯 Clustering Algorithms", (
            "K-Means clustering",
            "DBSCAN clustering",
            "Optimal cluster determination",
            "Performance metrics",
        )),
        ("📈 Visualizations", (
            "2D cluster plots",
            "Distribution analysis",
            "Comparative visualizations",
            "Interactive charts",
        )),
    )
    for column, (title, bullets) in zip(st.columns(3), feature_cards):
        with column:
            st.markdown(f"**{title}**\n" + "\n".join(f"- {b}" for b in bullets))

    # Getting started
    st.markdown("### 🏁 Getting Started")
    st.markdown(
        "1. **📊 Data Overview**: Check your dataset information and statistics "
        "(automatically loaded from `data/Mall_Customers.csv`)\n"
        "2. **🔍 Data Exploration**: Explore distributions, correlations, and relationships\n"
        "3. **⚙️ Preprocessing**: Select features and scale your data for clustering\n"
        "4. **🎯 K-Means**: Apply K-Means clustering with optimal cluster determination\n"
        "5. **🌟 DBSCAN**: Try density-based clustering for comparison\n"
        "6. **📈 Comparison**: Compare results from both algorithms\n"
        "7. **📋 Insights**: Get business recommendations for each customer segment"
    )

    # Quick start note
    st.info(
        "💡 **Quick Start**: Your dataset is automatically loaded from the "
        "`data/` folder. Just click on the tabs above to start exploring and "
        "clustering your customer data!"
    )

    # Sample data info
    st.markdown("### 📋 Sample Dataset")
    st.info(
        "The sample dataset simulates mall customer data with the following features:\n"
        "- **CustomerID**: Unique identifier\n"
        "- **Gender**: Customer gender (Male/Female)\n"
        "- **Age**: Customer age (18-70 years)\n"
        "- **Annual Income (k$)**: Annual income in thousands\n"
        "- **Spending Score (1-100)**: Mall-assigned spending score"
    )


def show_data_overview():
    """Display data overview page."""
    _section_header("📊 Data Overview")

    if not _data_available():
        st.warning("⚠️ Please load data first using the sidebar.")
        return

    data_loader = st.session_state.data_loader
    data = data_loader.data
    data_info = data_loader.get_data_info()
    n_rows = data_info['shape'][0]

    # Basic information
    metrics = [
        ("Total Customers", n_rows),
        ("Features", data_info['shape'][1]),
        ("Missing Values", sum(data_info['missing_values'].values())),
        # Count every numeric dtype (int32, float32, ...), matching the
        # candidate features offered on the Preprocessing tab.
        ("Numeric Features", data.select_dtypes(include=[np.number]).shape[1]),
    ]
    for column, (label, value) in zip(st.columns(4), metrics):
        with column:
            st.metric(label, value)

    # Data preview
    st.subheader("📋 Data Preview")
    st.dataframe(data.head(10), use_container_width=True)

    # Data types and missing values
    dtype_col, missing_col = st.columns(2)
    with dtype_col:
        st.subheader("🔧 Data Types")
        dtypes_df = pd.DataFrame(list(data_info['dtypes'].items()),
                                 columns=['Column', 'Data Type'])
        # Plain strings keep the column Arrow-serializable for st.dataframe.
        dtypes_df['Data Type'] = dtypes_df['Data Type'].astype(str)
        st.dataframe(dtypes_df, use_container_width=True)

    with missing_col:
        st.subheader("❓ Missing Values")
        missing_df = pd.DataFrame(list(data_info['missing_values'].items()),
                                  columns=['Column', 'Missing Count'])
        # Guard against an empty dataset (division by zero rows).
        missing_df['Missing %'] = (
            (missing_df['Missing Count'] / n_rows * 100).round(2) if n_rows else 0.0
        )
        st.dataframe(missing_df, use_container_width=True)

    # Statistical summary
    st.subheader("📈 Statistical Summary")
    st.dataframe(data.describe(), use_container_width=True)


def show_data_exploration():
    """Display data exploration page."""
    _section_header("🔍 Data Exploration")

    if not _data_available():
        st.warning("⚠️ Please load data first using the sidebar.")
        return

    # Generate exploration visualizations
    st.session_state.visualizer.plot_data_exploration(st.session_state.data_loader.data)


def show_preprocessing():
    """Display preprocessing page."""
    _section_header("⚙️ Data Preprocessing")

    if not _data_available():
        st.warning("⚠️ Please load data first using the sidebar.")
        return

    data = st.session_state.data_loader.data

    # Feature selection: numeric columns except the customer identifier.
    st.subheader("🎯 Feature Selection")
    numeric_columns = [
        col for col in data.select_dtypes(include=[np.number]).columns.tolist()
        if col != 'CustomerID'
    ]
    if all(col in numeric_columns for col in DEFAULT_FEATURES):
        default_features = DEFAULT_FEATURES
    else:
        default_features = numeric_columns[:2]
    selected_features = st.multiselect(
        "Select features for clustering:",
        numeric_columns,
        default=default_features,
    )

    if len(selected_features) < 2:
        st.error("⚠️ Please select at least 2 features for clustering.")
        return

    # Preprocessing options
    st.subheader("🔧 Preprocessing Options")
    missing_col, scaling_col = st.columns(2)
    # NOTE(review): neither selection below is passed to preprocess_data(), so
    # both currently have no effect -- confirm DataLoader.preprocess_data's
    # signature before wiring them in.
    with missing_col:
        handle_missing = st.selectbox("Handle missing values:",  # noqa: F841
                                      ["Fill with mean", "Drop rows", "No action"])
    with scaling_col:
        scaling_method = st.selectbox("Scaling method:",  # noqa: F841
                                      ["StandardScaler", "MinMaxScaler", "No scaling"])

    # Apply preprocessing
    if not st.button("🚀 Apply Preprocessing"):
        return

    scaled_data = st.session_state.data_loader.preprocess_data(selected_features)
    if scaled_data is None:
        return

    # Labels and fitted models from an earlier run refer to the previous
    # feature set, so clustering must be redone on the new scaled data.
    _reset_analysis_state()
    st.session_state.data_preprocessed = True

    # Show preprocessing results
    st.success("✅ Data preprocessing completed!")
    original_df = data[selected_features]
    scaled_df = pd.DataFrame(scaled_data, columns=selected_features)

    original_col, scaled_col = st.columns(2)
    with original_col:
        st.subheader("📊 Original Data")
        st.dataframe(original_df.head(), use_container_width=True)
    with scaled_col:
        st.subheader("🔄 Scaled Data")
        st.dataframe(scaled_df.head(), use_container_width=True)

    # Feature statistics
    st.subheader("📈 Feature Statistics")
    original_col, scaled_col = st.columns(2)
    with original_col:
        st.write("**Original Data Statistics:**")
        st.dataframe(original_df.describe(), use_container_width=True)
    with scaled_col:
        st.write("**Scaled Data Statistics:**")
        st.dataframe(scaled_df.describe(), use_container_width=True)


def show_kmeans_clustering():
    """Display K-Means clustering page."""
    _section_header("🎯 K-Means Clustering")

    if not st.session_state.data_preprocessed:
        st.warning("⚠️ Please preprocess data first.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    visualizer = st.session_state.visualizer

    # Optimal cluster determination
    st.subheader("🔍 Optimal Cluster Determination")
    range_col, search_col = st.columns([1, 1])
    with range_col:
        max_clusters = st.slider("Maximum clusters to test:", 2, 15, 10)
    with search_col:
        if st.button("🔍 Find Optimal Clusters"):
            with st.spinner("Finding optimal number of clusters..."):
                optimization_results = clustering_analyzer.find_optimal_clusters(
                    data_loader.scaled_data, max_clusters
                )
            if optimization_results:
                visualizer.plot_optimization_results(optimization_results)

    # K-Means clustering
    st.subheader("🎯 K-Means Clustering")
    k_col, run_col = st.columns([1, 1])
    with k_col:
        # The search above may suggest up to 15 clusters, but this slider stops
        # at MAX_CLUSTERS and st.slider raises on an out-of-range default.
        suggested_k = clustering_analyzer.optimal_clusters or 5
        default_k = min(max(int(suggested_k), MIN_CLUSTERS), MAX_CLUSTERS)
        n_clusters = st.slider("Number of clusters:", MIN_CLUSTERS, MAX_CLUSTERS, default_k)
    with run_col:
        if st.button("🚀 Apply K-Means"):
            # Clear any existing clustering results first to avoid column naming issues
            clustering_analyzer.cluster_labels = {}
            st.session_state.clustering_done = _empty_clustering_flags()
            # Clear any cached data
            st.session_state.pop('cluster_analysis_cache', None)

            with st.spinner("🔄 Applying K-Means clustering..."):
                kmeans_results = clustering_analyzer.apply_kmeans(data_loader.scaled_data, n_clusters)

            if kmeans_results:
                st.session_state.clustering_done['kmeans'] = True

                # Display metrics
                sil_col, ch_col, inertia_col = st.columns(3)
                with sil_col:
                    st.metric("Silhouette Score", f"{kmeans_results['silhouette_score']:.3f}")
                with ch_col:
                    st.metric("Calinski-Harabasz Score", f"{kmeans_results['calinski_score']:.1f}")
                with inertia_col:
                    st.metric("Inertia", f"{kmeans_results['inertia']:.1f}")

    # Visualizations
    if st.session_state.clustering_done['kmeans']:
        feature_data = data_loader.get_feature_data()
        kmeans_labels = clustering_analyzer.cluster_labels['kmeans']
        visualizer.plot_clusters(
            feature_data,
            kmeans_labels,
            'K-Means',
            data_loader.scaler,
            clustering_analyzer.kmeans_model.cluster_centers_,
        )

        # Cluster analysis
        analysis_results = clustering_analyzer.analyze_clusters(feature_data, 'kmeans')
        if analysis_results:
            visualizer.plot_cluster_analysis(analysis_results, 'K-Means')


def show_dbscan_clustering():
    """Display DBSCAN clustering page."""
    _section_header("🌟 DBSCAN Clustering")

    if not st.session_state.data_preprocessed:
        st.warning("⚠️ Please preprocess data first.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    visualizer = st.session_state.visualizer

    # DBSCAN parameters
    st.subheader("⚙️ DBSCAN Parameters")
    eps_col, samples_col = st.columns(2)
    with eps_col:
        eps = st.slider("Epsilon (neighborhood distance):", 0.1, 2.0, 0.5, 0.1)
    with samples_col:
        min_samples = st.slider("Minimum samples per cluster:", 2, 20, 5)

    # Parameter guidance
    st.info(
        "**Parameter Guidance:**\n"
        "- **Epsilon**: Maximum distance between points in the same cluster. "
        "Smaller values create more clusters.\n"
        "- **Min Samples**: Minimum number of points required to form a cluster. "
        "Higher values create fewer, denser clusters."
    )

    # Apply DBSCAN
    if st.button("🚀 Apply DBSCAN"):
        dbscan_results = clustering_analyzer.apply_dbscan(data_loader.scaled_data, eps, min_samples)
        if dbscan_results:
            st.session_state.clustering_done['dbscan'] = True

            # Display metrics
            count_col, noise_col, sil_col = st.columns(3)
            with count_col:
                st.metric("Number of Clusters", dbscan_results['n_clusters'])
            with noise_col:
                st.metric("Noise Points", dbscan_results['n_noise'])
            with sil_col:
                if 'silhouette_score' in dbscan_results:
                    st.metric("Silhouette Score", f"{dbscan_results['silhouette_score']:.3f}")
                else:
                    st.metric("Silhouette Score", "N/A")

    # Visualizations
    if st.session_state.clustering_done['dbscan']:
        feature_data = data_loader.get_feature_data()
        dbscan_labels = clustering_analyzer.cluster_labels['dbscan']
        visualizer.plot_clusters(feature_data, dbscan_labels, 'DBSCAN')

        # Cluster analysis
        analysis_results = clustering_analyzer.analyze_clusters(feature_data, 'dbscan')
        if analysis_results:
            visualizer.plot_cluster_analysis(analysis_results, 'DBSCAN')


def _silhouette_or_na(features, labels):
    """Return the silhouette score of ``labels`` formatted to 3 decimals.

    Returns "N/A" when scikit-learn cannot be imported or the score is
    undefined for these labels (silhouette_score raises ValueError, e.g. when
    fewer than two distinct labels are present).
    """
    try:
        from sklearn.metrics import silhouette_score
        return f"{silhouette_score(features, labels):.3f}"
    except (ImportError, ValueError):
        return "N/A"


def show_results_comparison():
    """Display results comparison page."""
    _section_header("📈 Results Comparison")

    done = st.session_state.clustering_done
    if not (done['kmeans'] and done['dbscan']):
        st.warning("⚠️ Please complete both K-Means and DBSCAN clustering first.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    visualizer = st.session_state.visualizer

    feature_data = data_loader.get_feature_data()
    kmeans_labels = clustering_analyzer.cluster_labels['kmeans']
    dbscan_labels = clustering_analyzer.cluster_labels['dbscan']

    # Comparison visualization
    visualizer.plot_comparison(feature_data, kmeans_labels, dbscan_labels)

    # Performance comparison
    st.subheader("📊 Performance Metrics Comparison")

    # Arrays make the noise mask work even if the analyzer stores plain lists.
    scaled = np.asarray(data_loader.scaled_data)
    km = np.asarray(kmeans_labels)
    db = np.asarray(dbscan_labels)
    clustered = db != -1  # DBSCAN labels noise points with -1

    km_sizes = pd.Series(km).value_counts()
    db_sizes = pd.Series(db[clustered]).value_counts()  # noise excluded

    # DBSCAN silhouette is computed on non-noise points only.
    if clustered.sum() > 1:
        dbscan_silhouette = _silhouette_or_na(scaled[clustered], db[clustered])
    else:
        dbscan_silhouette = "N/A"

    # Each silhouette is computed independently, so one failure can no longer
    # leave the two columns with different lengths. All cells are strings so
    # each column has a single Arrow-friendly type.
    comparison_df = pd.DataFrame({
        'Metric': ['Number of Clusters', 'Silhouette Score', 'Noise Points', 'Largest Cluster Size'],
        'K-Means': [
            str(km_sizes.size),
            _silhouette_or_na(scaled, km),
            "0",
            str(km_sizes.max()),
        ],
        'DBSCAN': [
            str(db_sizes.size),
            dbscan_silhouette,
            str(int((~clustered).sum())),
            str(db_sizes.max()) if len(db_sizes) > 0 else "0",
        ],
    })
    st.dataframe(comparison_df, use_container_width=True)


def _segment_recommendations(avg_income, avg_spending):
    """Return marketing recommendations for a segment's income/spending mix.

    Args:
        avg_income: Mean annual income of the segment (k$).
        avg_spending: Mean spending score of the segment.

    Returns:
        A list of recommendation strings.
    """
    high_income = avg_income > HIGH_THRESHOLD
    low_income = avg_income < LOW_THRESHOLD
    high_spending = avg_spending > HIGH_THRESHOLD
    low_spending = avg_spending < LOW_THRESHOLD

    if high_income and high_spending:
        return [
            "Focus on premium products and exclusive services",
            "Implement VIP loyalty programs",
            "Offer personalized shopping experiences",
        ]
    if high_income and low_spending:
        return [
            "Develop targeted upselling strategies",
            "Showcase value propositions",
            "Create incentive programs to increase spending",
        ]
    if low_income and high_spending:
        return [
            "Offer value-based products and promotions",
            "Focus on customer retention programs",
            "Provide flexible payment options",
        ]
    if low_income and low_spending:
        return [
            "Implement engagement and retention strategies",
            "Offer budget-friendly options",
            "Focus on building brand loyalty",
        ]
    return [
        "Balanced marketing approach",
        "Personalized offers based on preferences",
        "Regular engagement campaigns",
    ]


def show_business_insights():
    """Display business insights page."""
    _section_header("📋 Business Insights")

    if not st.session_state.clustering_done['kmeans']:
        st.warning("⚠️ Please complete K-Means clustering first to generate insights.")
        return

    data_loader = st.session_state.data_loader
    clustering_analyzer = st.session_state.clustering_analyzer
    feature_data = data_loader.get_feature_data()

    # Generate customer profiles
    profiles = clustering_analyzer.get_cluster_profiles(feature_data, 'kmeans')

    if profiles:
        st.subheader("👥 Customer Segment Profiles")

        for profile in profiles:
            title = f"🏷️ Cluster {profile['cluster']} - {profile.get('type', 'Unknown Type')}"
            with st.expander(title):
                overview_col, finance_col = st.columns(2)

                with overview_col:
                    st.markdown("**📊 Segment Overview**")
                    st.write(f"- **Size**: {profile['size']} customers ({profile['percentage']:.1f}%)")
                    if 'description' in profile:
                        st.write(f"- **Profile**: {profile['description']}")
                    if 'avg_age' in profile:
                        st.write(f"- **Average Age**: {profile['avg_age']:.1f} ± {profile['age_std']:.1f} years")
                    if 'gender_dist' in profile:
                        st.write(f"- **Gender Distribution**: {profile['gender_dist']}")

                with finance_col:
                    st.markdown("**💰 Financial Profile**")
                    if 'avg_income' in profile:
                        # "$" is escaped: Streamlit Markdown renders a $...$
                        # pair as inline LaTeX.
                        st.write(
                            f"- **Average Income**: \\${profile['avg_income']:.1f}k "
                            f"± \\${profile['income_std']:.1f}k"
                        )
                    if 'avg_spending' in profile:
                        st.write(
                            f"- **Average Spending Score**: {profile['avg_spending']:.1f} "
                            f"± {profile['spending_std']:.1f}"
                        )

                # Business recommendations
                st.markdown("**📈 Recommendations**")
                if 'avg_income' in profile and 'avg_spending' in profile:
                    for tip in _segment_recommendations(profile['avg_income'], profile['avg_spending']):
                        st.write(f"- {tip}")

    # Overall business strategy
    st.subheader("🎯 Overall Business Strategy")
    marketing_col, growth_col = st.columns(2)
    with marketing_col:
        st.markdown(
            "**🎯 Marketing Strategies**\n"
            "- **Segment-specific campaigns**: Tailor marketing messages to each cluster\n"
            "- **Product positioning**: Align products with cluster preferences\n"
            "- **Channel optimization**: Use preferred communication channels per segment\n"
            "- **Pricing strategies**: Implement dynamic pricing based on segment characteristics"
        )
    with growth_col:
        st.markdown(
            "**💡 Growth Opportunities**\n"
            "- **Cross-selling**: Identify products popular in high-spending segments\n"
            "- **Retention programs**: Focus on segments with declining engagement\n"
            "- **New product development**: Create offerings for underserved segments\n"
            "- **Customer lifetime value**: Invest more in high-value segments"
        )

    # Download results
    st.subheader("💾 Download Results")

    # Prepare data for download: the features plus each row's K-Means cluster.
    result_data = feature_data.copy()
    result_data['KMeans_Cluster'] = clustering_analyzer.cluster_labels['kmeans']

    st.download_button(
        label="📥 Download Customer Segments (CSV)",
        data=result_data.to_csv(index=False),
        file_name="customer_segments_results.csv",
        mime="text/csv",
    )


if __name__ == "__main__":
    main()