DataWizard9742 commited on
Commit
a3a4f91
Β·
verified Β·
1 Parent(s): e74c6fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib
3
+ matplotlib.use('Agg') # Non-interactive backend for Gradio
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.cm as cm
9
+ import seaborn as sns
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from plotly.subplots import make_subplots
13
+ import warnings
14
+ warnings.filterwarnings('ignore')
15
+
16
+ from sklearn.preprocessing import StandardScaler
17
+ from sklearn.cluster import KMeans
18
+ from sklearn.decomposition import PCA
19
+ from sklearn.manifold import TSNE
20
+ from sklearn.metrics import silhouette_score, silhouette_samples
21
+ import joblib
22
+ import io, base64
23
+
24
+ # Load saved artifacts
25
+ kmeans_loaded = joblib.load('kmeans_model.pkl')
26
+ scaler_loaded = joblib.load('scaler.pkl')
27
+ with open('cluster_names.json') as f:
28
+ cluster_names_loaded = {int(k): v for k, v in json.load(f).items()}
29
+ with open('cluster_insights.json') as f:
30
+ insights_loaded = {int(k): v for k, v in json.load(f).items()}
31
+
32
+ SEGMENT_COLORS = {
33
+ 0: '#FF6B6B', 1: '#4ECDC4', 2: '#45B7D1', 3: '#96CEB4', 4: '#FFEAA7'
34
+ }
35
+ SEGMENT_EMOJIS = {0: '⚠️', 1: 'πŸš€', 2: 'πŸ§‘β€πŸ’Ό', 3: 'πŸ’°', 4: 'πŸ‘‘'}
36
+
37
+
38
+ def make_radar_chart(cluster_id):
39
+ """Generate a radar chart for the predicted cluster."""
40
+ centers = (kmeans_loaded.cluster_centers_ - kmeans_loaded.cluster_centers_.min(axis=0)) / \
41
+ (kmeans_loaded.cluster_centers_.max(axis=0) - kmeans_loaded.cluster_centers_.min(axis=0))
42
+
43
+ categories = ['Age', 'Annual Income', 'Spending Score']
44
+ angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
45
+ angles += angles[:1]
46
+
47
+ fig, ax = plt.subplots(figsize=(4, 4), subplot_kw=dict(polar=True))
48
+ fig.patch.set_facecolor('#1a1a2e')
49
+ ax.set_facecolor('#16213e')
50
+
51
+ for i in range(len(centers)):
52
+ vals = centers[i].tolist() + [centers[i][0]]
53
+ color = SEGMENT_COLORS[i]
54
+ lw = 3 if i == cluster_id else 1
55
+ alpha_fill = 0.4 if i == cluster_id else 0.05
56
+ ax.plot(angles, vals, 'o-', linewidth=lw, color=color,
57
+ label=cluster_names_loaded[i], alpha=1.0 if i == cluster_id else 0.4)
58
+ ax.fill(angles, vals, alpha=alpha_fill, color=color)
59
+
60
+ ax.set_thetagrids(np.degrees(angles[:-1]), categories, color='white', fontsize=9)
61
+ ax.set_ylim(0, 1)
62
+ ax.set_title(f'{SEGMENT_EMOJIS[cluster_id]} {cluster_names_loaded[cluster_id]}',
63
+ color='white', fontsize=11, fontweight='bold', pad=20)
64
+ ax.tick_params(colors='white')
65
+ ax.spines['polar'].set_color('#333')
66
+ ax.yaxis.set_tick_params(colors='#555')
67
+ ax.set_yticklabels([])
68
+ ax.grid(color='#333', linestyle='--', alpha=0.5)
69
+
70
+ plt.tight_layout()
71
+ return fig
72
+
73
+
74
+ def make_comparison_bar(user_vals, cluster_id):
75
+ """Bar chart: user values vs cluster centroid."""
76
+ centroid = scaler_loaded.inverse_transform(
77
+ kmeans_loaded.cluster_centers_[cluster_id].reshape(1, -1)
78
+ )[0]
79
+
80
+ features = ['Age', 'Annual Income (k$)', 'Spending Score']
81
+ x = np.arange(len(features))
82
+ width = 0.35
83
+
84
+ fig, ax = plt.subplots(figsize=(6, 3.5))
85
+ fig.patch.set_facecolor('#1a1a2e')
86
+ ax.set_facecolor('#16213e')
87
+
88
+ bars1 = ax.bar(x - width/2, user_vals, width, label='You',
89
+ color=SEGMENT_COLORS[cluster_id], alpha=0.9, edgecolor='white')
90
+ bars2 = ax.bar(x + width/2, centroid, width, label='Cluster Avg',
91
+ color='#aaaaaa', alpha=0.6, edgecolor='white')
92
+
93
+ ax.set_xticks(x)
94
+ ax.set_xticklabels(features, color='white', fontsize=9)
95
+ ax.set_ylabel('Value', color='white')
96
+ ax.set_title('You vs Cluster Average', color='white', fontweight='bold')
97
+ ax.tick_params(colors='white')
98
+ ax.spines[['top','right','left','bottom']].set_color('#333')
99
+ ax.yaxis.set_tick_params(colors='white')
100
+ ax.legend(facecolor='#222', labelcolor='white', fontsize=9)
101
+
102
+ for bar in bars1:
103
+ ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
104
+ f'{bar.get_height():.0f}', ha='center', va='bottom',
105
+ color='white', fontsize=8)
106
+
107
+ plt.tight_layout()
108
+ return fig
109
+
110
+
111
+ def predict_segment(age, annual_income, spending_score):
112
+ """Core prediction function called by Gradio."""
113
+ user_input = np.array([[age, annual_income, spending_score]])
114
+ user_scaled = scaler_loaded.transform(user_input)
115
+ cluster_id = int(kmeans_loaded.predict(user_scaled)[0])
116
+
117
+ info = insights_loaded[cluster_id]
118
+ color = SEGMENT_COLORS[cluster_id]
119
+ emoji = SEGMENT_EMOJIS[cluster_id]
120
+
121
+ # Distance to all centroids
122
+ dists = kmeans_loaded.transform(user_scaled)[0]
123
+ confidence = 1 - (dists[cluster_id] / dists.sum())
124
+
125
+ result_html = f"""
126
+ <div style="background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
127
+ border-radius: 16px; padding: 24px; color: white; font-family: sans-serif;">
128
+ <div style="text-align:center; margin-bottom: 16px;">
129
+ <div style="font-size: 48px;">{emoji}</div>
130
+ <div style="font-size: 26px; font-weight: bold; color: {color};">
131
+ Cluster {cluster_id}: {cluster_names_loaded[cluster_id]}
132
+ </div>
133
+ <div style="font-size: 13px; color: #aaa; margin-top: 4px;">
134
+ Confidence: {confidence:.1%}
135
+ </div>
136
+ </div>
137
+ <hr style="border-color: #333; margin: 12px 0;">
138
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px;">
139
+ <div style="background: #0f3460; border-radius: 10px; padding: 14px;">
140
+ <div style="font-size: 11px; color: #aaa; text-transform: uppercase; letter-spacing: 1px;">Profile</div>
141
+ <div style="margin-top: 6px; font-size: 14px;">{info['desc']}</div>
142
+ </div>
143
+ <div style="background: #0f3460; border-radius: 10px; padding: 14px;">
144
+ <div style="font-size: 11px; color: #aaa; text-transform: uppercase; letter-spacing: 1px;">🎯 Recommended Strategy</div>
145
+ <div style="margin-top: 6px; font-size: 14px; color: {color};">{info['strategy']}</div>
146
+ </div>
147
+ </div>
148
+ <div style="margin-top: 14px; background: #0f3460; border-radius: 10px; padding: 14px;">
149
+ <div style="font-size: 11px; color: #aaa; margin-bottom: 8px;">πŸ“ Distance to All Centroids (lower = closer)</div>
150
+ {''.join([
151
+ f'<div style="display:flex; align-items:center; margin-bottom:6px;">' +
152
+ f'<span style="width:120px; font-size:12px; color:{SEGMENT_COLORS[i]};">{cluster_names_loaded[i][:12]}</span>' +
153
+ f'<div style="flex:1; height:8px; background:#1a1a2e; border-radius:4px; overflow:hidden;">' +
154
+ f'<div style="height:8px; width:{min(100, dists[i]/max(dists)*100):.0f}%; background:{SEGMENT_COLORS[i]}; border-radius:4px;"></div>' +
155
+ f'</div><span style="margin-left:8px; font-size:12px; color:#aaa;">{dists[i]:.2f}</span></div>'
156
+ for i in range(K_OPTIMAL)
157
+ ])}
158
+ </div>
159
+ </div>
160
+ """
161
+
162
+ radar = make_radar_chart(cluster_id)
163
+ bar = make_comparison_bar([age, annual_income, spending_score], cluster_id)
164
+
165
+ return result_html, radar, bar
166
+
167
+
168
+ # ─── Gradio UI ───────────────────────────────────────────────────────────────
169
+ css = """
170
+ .gradio-container { max-width: 960px !important; margin: auto; font-family: 'Segoe UI', sans-serif; }
171
+ #title { text-align: center; margin-bottom: 20px; }
172
+ .input-panel { background: #16213e; border-radius: 12px; padding: 16px; }
173
+ """
174
+
175
+ EXAMPLES = [
176
+ [25, 80, 90],
177
+ [45, 30, 20],
178
+ [35, 60, 55],
179
+ [22, 15, 85],
180
+ [55, 90, 15],
181
+ ]
182
+
183
+ with gr.Blocks(css=css, theme=gr.themes.Base(primary_hue='blue'), title='Customer Segmentation') as demo:
184
+
185
+ gr.HTML("""
186
+ <div id="title">
187
+ <h1 style="font-size:2em; margin-bottom:4px;">πŸ›οΈ Customer Segmentation</h1>
188
+ <p style="color:#888;">K-Means Clustering Β· 5 Customer Segments Β· Real-time Prediction</p>
189
+ </div>
190
+ """)
191
+
192
+ with gr.Row():
193
+ with gr.Column(scale=1, elem_classes='input-panel'):
194
+ gr.Markdown('### πŸ“ Enter Customer Details')
195
+ age_inp = gr.Slider(18, 70, value=30, step=1, label='Age')
196
+ income_inp = gr.Slider(10, 140, value=60, step=1, label='Annual Income (k$)')
197
+ spend_inp = gr.Slider(1, 100, value=50, step=1, label='Spending Score (1–100)')
198
+ predict_btn = gr.Button('πŸ” Predict Segment', variant='primary', size='lg')
199
+
200
+ gr.Markdown('### πŸ’‘ Try Examples')
201
+ gr.Examples(
202
+ examples=EXAMPLES,
203
+ inputs=[age_inp, income_inp, spend_inp],
204
+ label='Quick Examples'
205
+ )
206
+
207
+ with gr.Column(scale=2):
208
+ result_html = gr.HTML(label='Segment Result')
209
+ with gr.Row():
210
+ radar_plot = gr.Plot(label='Cluster Radar Profile')
211
+ bar_plot = gr.Plot(label='You vs Cluster Average')
212
+
213
+ gr.Markdown("""
214
+ ---
215
+ **Segments:** ⚠️ Cautious Savers Β· πŸš€ High Potential Β· πŸ§‘β€πŸ’Ό Standard Customers Β· πŸ’° Budget Shoppers Β· πŸ‘‘ Premium Loyalists
216
+ **Model:** K-Means (K=5, k-means++ init) Β· Scaler: StandardScaler Β· Dataset: Mall Customers
217
+ """)
218
+
219
+ predict_btn.click(
220
+ fn=predict_segment,
221
+ inputs=[age_inp, income_inp, spend_inp],
222
+ outputs=[result_html, radar_plot, bar_plot]
223
+ )
224
+
225
+ # Launch
226
+ demo.launch(share=True, debug=False)