Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use('Agg') # Non-interactive backend for Gradio | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| import seaborn as sns | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.cluster import KMeans | |
| from sklearn.decomposition import PCA | |
| from sklearn.manifold import TSNE | |
| from sklearn.metrics import silhouette_score, silhouette_samples | |
| import joblib | |
| import io, base64 | |
| import json | |
# Number of customer segments (clusters) shown by the app.
K_OPTIMAL = 5


def _load_int_keyed_json(path):
    """Load a JSON object from *path* and return it with ``int`` keys.

    JSON object keys are always strings, whereas the rest of the module
    indexes these mappings by integer cluster id.
    """
    # Explicit UTF-8 so decoding does not depend on the platform default.
    with open(path, encoding='utf-8') as f:
        return {int(key): value for key, value in json.load(f).items()}


# Load saved artifacts.
# NOTE(review): joblib.load unpickles arbitrary objects; only load files
# that come from a trusted source.
kmeans_loaded = joblib.load('kmeans_model.pkl')
scaler_loaded = joblib.load('scaler.pkl')
cluster_names_loaded = _load_int_keyed_json('cluster_names.json')
insights_loaded = _load_int_keyed_json('cluster_insights.json')

# Display colour per cluster id.
SEGMENT_COLORS = {
    0: '#FF6B6B',
    1: '#4ECDC4',
    2: '#45B7D1',
    3: '#96CEB4',
    4: '#FFEAA7',
}

# Display emoji per cluster id.
# NOTE(review): these literals look mis-decoded (UTF-8 bytes read with a
# single-byte codec); restore them from the original file — TODO confirm.
SEGMENT_EMOJIS = {0: 'β οΈ', 1: 'π', 2: 'π§βπΌ', 3: 'π°', 4: 'π'}
def make_radar_chart(cluster_id):
    """Generate a radar chart for the predicted cluster.

    Every centroid of the loaded K-Means model is min-max normalised per
    feature and drawn as a closed polygon on a polar axis; the polygon
    belonging to ``cluster_id`` is drawn thicker and more opaque.

    Args:
        cluster_id: Integer id of the cluster to highlight.

    Returns:
        The matplotlib figure holding the polar plot.
    """
    raw_centers = kmeans_loaded.cluster_centers_
    lowest = raw_centers.min(axis=0)
    span = raw_centers.max(axis=0) - lowest
    # If all centroids share a value on some feature the span is zero;
    # substitute 1 so that feature normalises to 0 instead of 0/0 (NaN).
    span = np.where(span == 0, 1.0, span)
    centers = (raw_centers - lowest) / span

    categories = ['Age', 'Annual Income', 'Spending Score']
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    angles += angles[:1]  # repeat the first angle to close the polygon

    fig, ax = plt.subplots(figsize=(4, 4), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor('#1a1a2e')
    ax.set_facecolor('#16213e')

    for idx, center in enumerate(centers):
        highlighted = idx == cluster_id
        # Repeat the first value so the outline returns to its start.
        vals = center.tolist() + [center[0]]
        color = SEGMENT_COLORS[idx]
        ax.plot(angles, vals, 'o-',
                linewidth=3 if highlighted else 1,
                color=color,
                label=cluster_names_loaded[idx],
                alpha=1.0 if highlighted else 0.4)
        ax.fill(angles, vals, alpha=0.4 if highlighted else 0.05, color=color)

    ax.set_thetagrids(np.degrees(angles[:-1]), categories, color='white', fontsize=9)
    ax.set_ylim(0, 1)
    ax.set_title(f'{SEGMENT_EMOJIS[cluster_id]} {cluster_names_loaded[cluster_id]}',
                 color='white', fontsize=11, fontweight='bold', pad=20)
    ax.tick_params(colors='white')
    ax.spines['polar'].set_color('#333')
    ax.yaxis.set_tick_params(colors='#555')
    ax.set_yticklabels([])
    ax.grid(color='#333', linestyle='--', alpha=0.5)

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    # Drop pyplot's reference so figures do not pile up across requests;
    # the returned Figure object can still be rendered by the caller.
    plt.close(fig)
    return fig
def make_comparison_bar(user_vals, cluster_id):
    """Bar chart: user values vs cluster centroid.

    The centroid of ``cluster_id`` is mapped back through the loaded
    scaler's ``inverse_transform`` and plotted next to the user's values.

    Args:
        user_vals: The three values to plot, in the order age, annual
            income, spending score.
        cluster_id: Integer id of the cluster whose centroid is compared.

    Returns:
        The matplotlib figure holding the grouped bar chart.
    """
    centroid = scaler_loaded.inverse_transform(
        kmeans_loaded.cluster_centers_[cluster_id].reshape(1, -1)
    )[0]

    features = ['Age', 'Annual Income (k$)', 'Spending Score']
    x = np.arange(len(features))
    width = 0.35

    fig, ax = plt.subplots(figsize=(6, 3.5))
    fig.patch.set_facecolor('#1a1a2e')
    ax.set_facecolor('#16213e')

    user_bars = ax.bar(x - width / 2, user_vals, width, label='You',
                       color=SEGMENT_COLORS[cluster_id], alpha=0.9, edgecolor='white')
    ax.bar(x + width / 2, centroid, width, label='Cluster Avg',
           color='#aaaaaa', alpha=0.6, edgecolor='white')

    ax.set_xticks(x)
    ax.set_xticklabels(features, color='white', fontsize=9)
    ax.set_ylabel('Value', color='white')
    ax.set_title('You vs Cluster Average', color='white', fontweight='bold')
    ax.tick_params(colors='white')
    ax.spines[['top', 'right', 'left', 'bottom']].set_color('#333')
    ax.yaxis.set_tick_params(colors='white')
    ax.legend(facecolor='#222', labelcolor='white', fontsize=9)

    # Print the user's value just above each of the user's bars.
    for bar in user_bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2, height + 0.5,
                f'{height:.0f}', ha='center', va='bottom',
                color='white', fontsize=8)

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    # Drop pyplot's reference so figures do not pile up across requests;
    # the returned Figure object can still be rendered by the caller.
    plt.close(fig)
    return fig
def _distance_bars_html(dists):
    """Return one labelled HTML bar per centroid, scaled to the largest distance."""
    longest = float(np.max(dists))
    # Avoid dividing by zero in the degenerate case where every distance is 0.
    scale = longest if longest > 0 else 1.0
    rows = []
    for idx, dist in enumerate(dists):
        color = SEGMENT_COLORS[idx]
        rows.append(
            f'<div style="display:flex; align-items:center; margin-bottom:6px;">'
            f'<span style="width:120px; font-size:12px; color:{color};">{cluster_names_loaded[idx][:12]}</span>'
            f'<div style="flex:1; height:8px; background:#1a1a2e; border-radius:4px; overflow:hidden;">'
            f'<div style="height:8px; width:{min(100, dist / scale * 100):.0f}%; background:{color}; border-radius:4px;"></div>'
            f'</div><span style="margin-left:8px; font-size:12px; color:#aaa;">{dist:.2f}</span></div>'
        )
    return ''.join(rows)


def predict_segment(age, annual_income, spending_score):
    """Core prediction function called by Gradio.

    Scales the three inputs with the loaded scaler, assigns them to a
    cluster with the loaded K-Means model, and builds the result card and
    the two charts.

    Args:
        age: Value for the age feature.
        annual_income: Value for the annual income feature.
        spending_score: Value for the spending score feature.

    Returns:
        A tuple ``(result_html, radar, bar)``: the HTML result card, the
        radar chart figure and the comparison bar chart figure.
    """
    user_input = np.array([[age, annual_income, spending_score]])
    user_scaled = scaler_loaded.transform(user_input)
    cluster_id = int(kmeans_loaded.predict(user_scaled)[0])

    info = insights_loaded[cluster_id]
    color = SEGMENT_COLORS[cluster_id]
    emoji = SEGMENT_EMOJIS[cluster_id]

    # Distance to all centroids
    dists = kmeans_loaded.transform(user_scaled)[0]
    # Heuristic score: one minus the assigned cluster's share of the summed
    # distances. It is not a calibrated probability.
    confidence = 1 - (dists[cluster_id] / dists.sum())

    # One bar per cluster in the model, instead of a hard-coded count.
    distance_bars = _distance_bars_html(dists)

    # NOTE(review): the non-ASCII icons in the two headings below look
    # mis-decoded; restore them from the original file — TODO confirm.
    result_html = f"""
    <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
    border-radius: 16px; padding: 24px; color: white; font-family: sans-serif;">
      <div style="text-align:center; margin-bottom: 16px;">
        <div style="font-size: 48px;">{emoji}</div>
        <div style="font-size: 26px; font-weight: bold; color: {color};">
          Cluster {cluster_id}: {cluster_names_loaded[cluster_id]}
        </div>
        <div style="font-size: 13px; color: #aaa; margin-top: 4px;">
          Confidence: {confidence:.1%}
        </div>
      </div>
      <hr style="border-color: #333; margin: 12px 0;">
      <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
        <div style="background: #0f3460; border-radius: 10px; padding: 14px;">
          <div style="font-size: 11px; color: #aaa; text-transform: uppercase; letter-spacing: 1px;">Profile</div>
          <div style="margin-top: 6px; font-size: 14px;">{info['desc']}</div>
        </div>
        <div style="background: #0f3460; border-radius: 10px; padding: 14px;">
          <div style="font-size: 11px; color: #aaa; text-transform: uppercase; letter-spacing: 1px;">π― Recommended Strategy</div>
          <div style="margin-top: 6px; font-size: 14px; color: {color};">{info['strategy']}</div>
        </div>
      </div>
      <div style="margin-top: 14px; background: #0f3460; border-radius: 10px; padding: 14px;">
        <div style="font-size: 11px; color: #aaa; margin-bottom: 8px;">π Distance to All Centroids (lower = closer)</div>
        {distance_bars}
      </div>
    </div>
    """

    radar = make_radar_chart(cluster_id)
    bar = make_comparison_bar([age, annual_income, spending_score], cluster_id)
    return result_html, radar, bar
# --- Gradio UI ---------------------------------------------------------------
# NOTE(review): several non-ASCII characters in the UI strings below look
# mis-decoded; they are reproduced as found — restore from the original file.

# Page-level CSS: width cap, centred title, dark card behind the inputs.
css = """
.gradio-container { max-width: 960px !important; margin: auto; font-family: 'Segoe UI', sans-serif; }
#title { text-align: center; margin-bottom: 20px; }
.input-panel { background: #16213e; border-radius: 12px; padding: 16px; }
"""

# Preset rows for the example picker: [age, annual income (k$), spending score].
EXAMPLES = [[25, 80, 90], [45, 30, 20], [35, 60, 55], [22, 15, 85], [55, 90, 15]]

with gr.Blocks(
    css=css,
    theme=gr.themes.Base(primary_hue='blue'),
    title='Customer Segmentation',
) as demo:
    # Page header.
    gr.HTML("""
<div id="title">
<h1 style="font-size:2em; margin-bottom:4px;">ποΈ Customer Segmentation</h1>
<p style="color:#888;">K-Means Clustering Β· 5 Customer Segments Β· Real-time Prediction</p>
</div>
""")

    with gr.Row():
        # Left column: the three sliders, the predict button and the examples.
        with gr.Column(scale=1, elem_classes='input-panel'):
            gr.Markdown('### π Enter Customer Details')
            age_inp = gr.Slider(minimum=18, maximum=70, value=30, step=1,
                                label='Age')
            income_inp = gr.Slider(minimum=10, maximum=140, value=60, step=1,
                                   label='Annual Income (k$)')
            spend_inp = gr.Slider(minimum=1, maximum=100, value=50, step=1,
                                  label='Spending Score (1β100)')
            segment_inputs = [age_inp, income_inp, spend_inp]

            predict_btn = gr.Button(value='π Predict Segment', variant='primary', size='lg')

            gr.Markdown('### π‘ Try Examples')
            gr.Examples(examples=EXAMPLES, inputs=segment_inputs, label='Quick Examples')

        # Right column: result card on top, the two charts side by side below.
        with gr.Column(scale=2):
            result_html = gr.HTML(label='Segment Result')
            with gr.Row():
                radar_plot = gr.Plot(label='Cluster Radar Profile')
                bar_plot = gr.Plot(label='You vs Cluster Average')
            segment_outputs = [result_html, radar_plot, bar_plot]

    # Footer with the segment legend and model summary.
    gr.Markdown("""
---
**Segments:** β οΈ Cautious Savers Β· π High Potential Β· π§βπΌ Standard Customers Β· π° Budget Shoppers Β· π Premium Loyalists
**Model:** K-Means (K=5, k-means++ init) Β· Scaler: StandardScaler Β· Dataset: Mall Customers
""")

    # Clicking the button runs the prediction and fills all three outputs.
    predict_btn.click(fn=predict_segment, inputs=segment_inputs, outputs=segment_outputs)

# Launch
# NOTE(review): share=True requests a public share link; confirm this is
# wanted in the environment where the app is deployed.
demo.launch(share=True, debug=False)