""" Visualization Module =================== This module handles all visualization components for the customer segmentation analysis. """ # Matplotlib and Seaborn removed to avoid extra dependency # All charts use Plotly for interactive visualization import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots import plotly.io as pio import pandas as pd import numpy as np import streamlit as st # Global Plotly template: dark backgrounds to match app theme pio.templates.default = "plotly_dark" pio.templates["plotly_dark"].layout.update( paper_bgcolor="#0F172A", plot_bgcolor="#0F172A", font=dict(color="#E5E7EB") ) # Plot styling handled via Plotly theme settings per figure class Visualizer: """ Handles all visualizations for customer segmentation analysis. """ def __init__(self): # Enhanced color palettes for better visual appeal self.colors = px.colors.qualitative.Set1 # More vibrant colors self.gradient_colors = [ '#FF6B6B', # Coral Red '#4ECDC4', # Turquoise '#45B7D1', # Sky Blue '#96CEB4', # Mint Green '#FFEAA7', # Warm Yellow '#DDA0DD', # Plum '#98D8C8', # Seafoam '#F7DC6F', # Golden Yellow '#BB8FCE', # Lavender '#85C1E9' # Light Blue ] self.modern_colors = [ '#6C5CE7', # Purple '#00B894', # Green '#E17055', # Orange '#0984E3', # Blue '#FDCB6E', # Yellow '#E84393', # Pink '#00CEC9', # Cyan '#A29BFE', # Light Purple '#FD79A8', # Light Pink '#81ECEC' # Light Cyan ] def plot_data_exploration(self, data): """Create comprehensive data exploration plots with enhanced styling.""" if data is None: st.error("❌ No data available for visualization.") return # Debug: Show data info st.info(f"🔍 **Data shape:** {data.shape}") st.info(f"🔍 **Data columns:** {list(data.columns)}") st.subheader("📊 Data Distribution Analysis") # Create subplots for different visualizations col1, col2 = st.columns(2) with col1: # Age distribution with enhanced styling if 'Age' in data.columns: st.write("📊 Creating Age distribution plot...") fig_age = px.histogram( data, x='Age', nbins=20, title='👥 Age Distribution', color_discrete_sequence=[self.gradient_colors[0]] ) fig_age.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', xaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')), yaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')) ) fig_age.update_traces(marker=dict(line=dict(width=1, color='white'))) st.plotly_chart(fig_age, use_container_width=True, theme=None) st.success("✅ Age distribution plot created!") # Income distribution with enhanced styling if 'Annual Income (k$)' in data.columns: st.write("💰 Creating Income distribution plot...") fig_income = px.histogram( data, x='Annual Income (k$)', nbins=20, title='💰 Annual Income Distribution', color_discrete_sequence=[self.gradient_colors[1]] ) fig_income.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', xaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')), yaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')) ) fig_income.update_traces(marker=dict(line=dict(width=1, color='white'))) st.plotly_chart(fig_income, use_container_width=True, theme=None) st.success("✅ Income distribution plot created!") with col2: # Spending Score distribution with enhanced styling if 'Spending Score (1-100)' in data.columns: st.write("🛍️ Creating Spending Score distribution plot...") fig_spending = px.histogram( data, x='Spending Score (1-100)', nbins=20, title='🛍️ Spending Score Distribution', color_discrete_sequence=[self.gradient_colors[2]] ) fig_spending.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', xaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')), yaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')) ) fig_spending.update_traces(marker=dict(line=dict(width=1, color='white'))) st.plotly_chart(fig_spending, use_container_width=True, theme=None) st.success("✅ Spending Score distribution plot created!") # Gender distribution with enhanced styling if 'Gender' in data.columns: gender_counts = data['Gender'].value_counts() fig_gender = px.pie( values=gender_counts.values, names=gender_counts.index, title='👫 Gender Distribution', color_discrete_sequence=self.modern_colors[:len(gender_counts)] ) fig_gender.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A' ) fig_gender.update_traces( textposition='inside', textinfo='percent+label', textfont_size=14, marker=dict(line=dict(color='white', width=2)) ) st.plotly_chart(fig_gender, use_container_width=True) # Enhanced correlation analysis st.subheader("🔗 Feature Correlations") numeric_cols = data.select_dtypes(include=[np.number]).columns if len(numeric_cols) > 1: corr_matrix = data[numeric_cols].corr() fig_corr = px.imshow( corr_matrix, text_auto=True, title='🔗 Feature Correlation Matrix', color_continuous_scale='RdYlBu', aspect='auto' ) fig_corr.update_layout( height=500, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', font=dict(size=12, color='#E5E7EB') ) fig_corr.update_traces( textfont=dict(size=12, color='#E5E7EB'), hoverongaps=False ) st.plotly_chart(fig_corr, theme=None, use_container_width=True) # Enhanced scatter plots st.subheader("🔍 Feature Relationships") col1, col2 = st.columns(2) with col1: if 'Annual Income (k$)' in data.columns and 'Spending Score (1-100)' in data.columns: fig_scatter1 = px.scatter( data, x='Annual Income (k$)', y='Spending Score (1-100)', title='💰 Income vs Spending Score', hover_data=['Age'] if 'Age' in data.columns else None, color_discrete_sequence=[self.modern_colors[3]] ) fig_scatter1.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', xaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')), yaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')) ) fig_scatter1.update_traces( marker=dict(size=8, opacity=0.7, line=dict(width=1, color='white')) ) st.plotly_chart(fig_scatter1, use_container_width=True) with col2: if 'Age' in data.columns and 'Spending Score (1-100)' in data.columns: fig_scatter2 = px.scatter( data, x='Age', y='Spending Score (1-100)', title='👥 Age vs Spending Score', hover_data=['Annual Income (k$)'] if 'Annual Income (k$)' in data.columns else None, color_discrete_sequence=[self.modern_colors[4]] ) fig_scatter2.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', xaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')), yaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')) ) fig_scatter2.update_traces( marker=dict(size=8, opacity=0.7, line=dict(width=1, color='white')) ) st.plotly_chart(fig_scatter2, use_container_width=True) def plot_optimization_results(self, results): """Plot cluster optimization results.""" if results is None: st.error("No optimization results available.") return # Create subplots fig = make_subplots( rows=1, cols=3, subplot_titles=('Elbow Method', 'Silhouette Score', 'Calinski-Harabasz Score'), specs=[[{"secondary_y": False}, {"secondary_y": False}, {"secondary_y": False}]] ) cluster_range = results['cluster_range'] # Elbow method fig.add_trace( go.Scatter(x=cluster_range, y=results['inertias'], mode='lines+markers', name='Inertia', line=dict(color='blue')), row=1, col=1 ) # Silhouette score fig.add_trace( go.Scatter(x=cluster_range, y=results['silhouette_scores'], mode='lines+markers', name='Silhouette Score', line=dict(color='red')), row=1, col=2 ) # Calinski-Harabasz score fig.add_trace( go.Scatter(x=cluster_range, y=results['calinski_scores'], mode='lines+markers', name='Calinski-Harabasz Score', line=dict(color='green')), row=1, col=3 ) # Update layout fig.update_layout( title_text="Cluster Optimization Results", height=400, showlegend=False, paper_bgcolor="#0F172A", plot_bgcolor="#0F172A", font=dict(color="#E5E7EB") ) fig.update_xaxes(title_text="Number of Clusters") fig.update_yaxes(title_text="Inertia", row=1, col=1) fig.update_yaxes(title_text="Silhouette Score", row=1, col=2) fig.update_yaxes(title_text="Calinski-Harabasz Score", row=1, col=3) st.plotly_chart(fig, theme=None, use_container_width=True) # Display optimal results col1, col2, col3 = st.columns(3) with col1: st.metric("Optimal Clusters (Silhouette)", results['optimal_silhouette']) with col2: st.metric("Optimal Clusters (Calinski-Harabasz)", results['optimal_calinski']) with col3: st.metric("Recommended", results['optimal_silhouette']) def plot_clusters(self, data, cluster_labels, algorithm='K-Means', scaler=None, centers=None): """Plot cluster visualizations.""" if data is None or cluster_labels is None: st.error("No data or cluster labels available for visualization.") return # Prepare data with clusters plot_data = data.copy() plot_data['Cluster'] = cluster_labels # Main clustering visualization st.subheader(f"🎯 {algorithm} Clustering Results") col1, col2 = st.columns(2) with col1: if 'Annual Income (k$)' in data.columns and 'Spending Score (1-100)' in data.columns: fig_main = px.scatter(plot_data, x='Annual Income (k$)', y='Spending Score (1-100)', color='Cluster', title=f'{algorithm}: Income vs Spending Score', hover_data=['Age'] if 'Age' in data.columns else None, color_discrete_sequence=self.colors) # Add cluster centers if available if centers is not None and scaler is not None: centers_original = scaler.inverse_transform(centers) centers_df = pd.DataFrame(centers_original, columns=['Annual Income (k$)', 'Spending Score (1-100)']) centers_df['Cluster'] = range(len(centers_df)) fig_main.add_scatter(x=centers_df['Annual Income (k$)'], y=centers_df['Spending Score (1-100)'], mode='markers', marker=dict(symbol='x', size=15, color='red', line=dict(width=2)), name='Centers', showlegend=True) fig_main.update_layout( height=500, paper_bgcolor="#0F172A", plot_bgcolor="#0F172A", font=dict(color="#E5E7EB"), xaxis=dict(gridcolor="rgba(229,231,235,0.12)"), yaxis=dict(gridcolor="rgba(229,231,235,0.12)") ) st.plotly_chart(fig_main, theme=None, use_container_width=True) with col2: if 'Age' in data.columns and 'Spending Score (1-100)' in data.columns: fig_age = px.scatter(plot_data, x='Age', y='Spending Score (1-100)', color='Cluster', title=f'{algorithm}: Age vs Spending Score', color_discrete_sequence=self.colors) fig_age.update_layout( height=500, paper_bgcolor="#0F172A", plot_bgcolor="#0F172A", font=dict(color="#E5E7EB"), xaxis=dict(gridcolor="rgba(229,231,235,0.12)"), yaxis=dict(gridcolor="rgba(229,231,235,0.12)") ) st.plotly_chart(fig_age, theme=None, use_container_width=True) # Enhanced cluster distribution st.subheader("📊 Cluster Distribution") cluster_counts = pd.Series(cluster_labels).value_counts().sort_index() fig_dist = px.bar( x=cluster_counts.index, y=cluster_counts.values, title='📊 Number of Customers per Cluster', labels={'x': 'Cluster', 'y': 'Number of Customers'}, color=cluster_counts.values, color_continuous_scale='Turbo' ) fig_dist.update_layout( height=450, title=dict(font=dict(size=18, color='#E5E7EB'), x=0.5), plot_bgcolor='#0F172A', paper_bgcolor='#0F172A', xaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')), yaxis=dict(gridcolor='rgba(229,231,235,0.12)', title_font=dict(size=14, color='#E5E7EB')) ) fig_dist.update_traces( marker=dict(line=dict(width=1, color='white')) ) st.plotly_chart(fig_dist, theme=None, use_container_width=True) def plot_cluster_analysis(self, analysis_results, algorithm='K-Means'): """Plot detailed cluster analysis with enhanced visualizations.""" if analysis_results is None: st.error("❌ No analysis results available.") return try: data_with_clusters = analysis_results['data_with_clusters'] spending_analysis = analysis_results['spending_analysis'] # COMPLETELY REWRITTEN: Find cluster column with bulletproof detection available_columns = list(data_with_clusters.columns) st.info(f"🔍 **Available columns in data:** {available_columns}") # Find ANY column that contains 'cluster' (case insensitive) cluster_columns = [col for col in available_columns if 'cluster' in col.lower()] st.info(f"🎯 **Found cluster columns:** {cluster_columns}") if not cluster_columns: st.error("❌ No cluster column found in the data!") st.write("Available columns:", available_columns) st.write("Please ensure clustering has been performed first.") return # Use the first cluster column found cluster_col = cluster_columns[0] st.success(f"✅ **Using cluster column:** `{cluster_col}`") # EXTRA SAFETY: Ensure the column actually exists before proceeding if cluster_col not in data_with_clusters.columns: st.error(f"❌ Column `{cluster_col}` not found in data!") st.write("This should not happen. Please report this bug.") return # Create a beautiful header with metrics st.markdown(f"""
Interactive Cluster Visualization & Analysis