|
|
|
|
|
""" |
|
|
Gradio Web Demo for Customer Segmentation Project |
|
|
|
|
|
A comprehensive interactive interface showcasing: |
|
|
- Tab 1: Dashboard with KPIs and EDA visualizations |
|
|
- Tab 2: Clustering Playground with interactive K selection |
|
|
- Tab 3: Customer DNA analysis with Radar charts |
|
|
- Tab 4: Segment Prediction for new customers |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import gradio as gr |
|
|
from datetime import datetime, timedelta |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.insert(0, '../src') |
|
|
|
|
|
|
|
|
from utils.data_loader import DataLoader, get_data_loader |
|
|
from utils.clustering_models import ClusteringModels, init_clustering_models |
|
|
from utils.visualizations import ( |
|
|
create_kpi_display, |
|
|
plot_revenue_over_time, |
|
|
plot_hourly_daily_heatmap, |
|
|
plot_elbow_silhouette, |
|
|
plot_clusters_pca_2d, |
|
|
plot_radar_chart, |
|
|
create_cluster_stats_table, |
|
|
) |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_app(data_dir="./data/processed", models_dir="./models", k_range=range(2, 11)):
    """Load the data, prepare the K-Means models and pre-render the PCA plots.

    Args:
        data_dir: Directory passed to ``DataLoader``.
        models_dir: Directory searched for ``kmeans_k*.pkl`` files and passed
            to ``ClusteringModels`` / ``init_clustering_models``.
        k_range: Re-iterable collection of cluster counts to load or train
            and to pre-render PCA scatter plots for.

    Returns:
        A 6-tuple ``(data_loader, cm, raw_data, scaled_features,
        original_features, pca_plots_cache)`` where ``pca_plots_cache`` maps
        each K that has cluster labels to the figure returned by
        ``plot_clusters_pca_2d``.
    """
    print("Initializing app...")

    data_loader = DataLoader(data_dir)
    scaled_features = data_loader.scaled_features
    original_features = data_loader.original_features
    raw_data = data_loader.raw_data

    cm = ClusteringModels(scaled_features, original_features, models_dir)

    # Reuse saved models when at least one K-Means pickle is present;
    # otherwise train from scratch and persist the result.
    models_path = Path(models_dir)
    if models_path.exists() and any(models_path.glob("kmeans_k*.pkl")):
        print("Loading pre-trained models...")
        cm.load_models(k_range=k_range)
    else:
        print("Models not found. Training models...")
        cm.train_models(k_range=k_range)
        cm.apply_pca(n_components=None)
        cm.save_models()

    # The load path may leave the PCA projection unset; compute it on demand.
    if cm.pca_features is None:
        print("Applying PCA...")
        cm.apply_pca(n_components=None)

    # NOTE(review): presumably registers a shared instance inside
    # utils.clustering_models - confirm against that module.
    init_clustering_models(scaled_features, original_features, models_dir)

    print("Pre-computing PCA plots for all K values...")
    pca_plots_cache = {}
    skipped_k = []
    for k in k_range:
        if k not in cm.cluster_labels:
            skipped_k.append(k)
            continue
        pca_plots_cache[k] = plot_clusters_pca_2d(cm.pca_features, cm.cluster_labels[k], k)
        print(f" Cached PCA plot for K={k}")

    # Only claim full success when every requested K actually got a plot.
    if skipped_k:
        print(f"Warning: no cluster labels for K in {skipped_k}; PCA plots were not cached for them.")
    else:
        print("All PCA plots cached successfully!")

    return data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache
|
|
|
|
|
|
|
|
|
|
|
# Shared application state. Every slot starts out as None; main() rebinds
# them (through its ``global`` declaration) to the values returned by
# initialize_app(), and the tab builders read them afterwards.
data_loader = cm = raw_data = None
scaled_features = original_features = pca_plots_cache = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_kpi_data():
    """Return whatever ``get_kpi_metrics()`` of the shared ``data_loader`` yields."""
    kpi_metrics = data_loader.get_kpi_metrics()
    return kpi_metrics
|
|
|
|
|
|
|
|
def get_dashboard_plots():
    """Build the three artefacts shown on the dashboard tab.

    Returns:
        A tuple ``(kpi_html, revenue_fig, heatmap_fig)``: the KPI markup from
        ``create_kpi_display``, the revenue figure and the hour/day heatmap,
        the latter two both built from the module-level ``raw_data``.

    Nothing is memoised in this function; each call rebuilds all three.
    """
    return (
        create_kpi_display(get_kpi_data()),
        plot_revenue_over_time(raw_data),
        plot_hourly_daily_heatmap(raw_data),
    )
|
|
|
|
|
|
|
|
def create_tab1():
    """Create Tab 1: Dashboard Overview.

    Lays out the KPI block, a revenue chart with a from/to date filter and
    the hourly/daily heatmap. Must be called inside an active ``gr.Blocks``
    context, after the module-level ``raw_data`` has been populated.
    """
    with gr.TabItem("Dashboard - Overview"):
        gr.Markdown("# Data Overview Analysis")

        kpi_html, revenue_fig, heatmap_fig = get_dashboard_plots()
        gr.HTML(kpi_html)

        gr.Markdown("## Revenue Over Time")

        # The pickers default to the full span of invoice dates in the data.
        invoice_dates = raw_data["InvoiceDate"]
        with gr.Row():
            date_start = gr.DateTime(label="From Date", value=invoice_dates.min())
            date_end = gr.DateTime(label="To Date", value=invoice_dates.max())

        revenue_plot = gr.Plot(label="Revenue Chart", value=revenue_fig)

        def update_revenue_plot(start, end):
            """Redraw the revenue chart for the chosen range.

            Falls back to the unfiltered figure while either bound is unset.
            """
            # NOTE(review): gr.DateTime's default ``type`` appears to hand
            # callbacks epoch-second numbers rather than datetimes - confirm
            # that plot_revenue_over_time accepts that form.
            if start is not None and end is not None:
                return plot_revenue_over_time(raw_data, start, end)
            return revenue_fig

        # Either picker changing triggers the same redraw with both bounds.
        for picker in (date_start, date_end):
            picker.change(
                fn=update_revenue_plot,
                inputs=[date_start, date_end],
                outputs=revenue_plot,
            )

        gr.Markdown("## Shopping Behavior by Hour and Day")
        gr.Plot(label="Shopping Activity Heatmap", value=heatmap_fig)

        gr.Markdown("""
        ### Insights:
        - **Heatmap** shows shopping patterns by hour (0-23) and day of week
        - **Revenue Over Time** shows overall sales trend (12 months)
        - Filter by date range to zoom into peak months (Christmas, etc.)
        """)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_optimal_clusters_data():
    """Collect the inputs for the Elbow / Silhouette chart.

    Returns:
        A tuple ``(inertias, silhouette_scores, k_list)``. The first two are
        read straight from the shared ``cm`` object; ``k_list`` is the list
        of integers 2 through 10.
    """
    candidate_ks = [*range(2, 11)]
    return cm.inertias, cm.silhouette_scores, candidate_ks
|
|
|
|
|
|
|
|
def create_tab2():
    """Create Tab 2: Clustering Playground.

    Shows the Elbow/Silhouette chart and a K slider that swaps in the
    pre-rendered PCA scatter plot for the selected K. Must be called inside
    an active ``gr.Blocks`` context, after ``cm`` and ``pca_plots_cache``
    have been populated.
    """
    default_k = 4

    def update_pca_plot(k):
        """Return the pre-rendered PCA figure for ``k``, or None if absent."""
        return pca_plots_cache.get(k)

    with gr.TabItem("Clustering - Playground"):
        gr.Markdown("# Explore K-Means Clustering Algorithm")

        gr.Markdown("""
        Adjust the slider to select different numbers of clusters (K) and see how the algorithm
        divides customers into different groups.
        """)

        # The returned K list is not needed here; the plot call below is
        # given the same 2..10 range directly.
        inertias, silhouette_scores, _ = get_optimal_clusters_data()

        gr.Markdown("## Determine Optimal Number of Clusters")
        gr.Plot(value=plot_elbow_silhouette(inertias, silhouette_scores, range(2, 11)))

        gr.Markdown("""
        **Explanation:**
        - **Elbow Method**: Find the "elbow" point where increasing K doesn't significantly reduce inertia
        - **Silhouette Score**: Higher is better. Clusters are more distinct when score is high
        - **Recommendation**: K=3 or K=4 are both good choices
        """)

        gr.Markdown("## Visualize Clusters in PCA Space")

        k_slider = gr.Slider(
            label="Select number of clusters (K)",
            minimum=2,
            maximum=10,
            step=1,
            value=default_k,
            interactive=True,
        )

        pca_plot = gr.Plot(
            label="Scatter Plot: PC1 vs PC2",
            value=update_pca_plot(default_k),
        )

        k_slider.change(fn=update_pca_plot, inputs=k_slider, outputs=pca_plot)

        gr.Markdown("""
        **How to Use:**
        - Each **point** represents one customer
        - **Color** indicates which cluster the customer belongs to
        - When changing K, clusters will be instantly updated from cache
        """)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_tab3():
    """Create Tab 3: Customer DNA.

    Lets the user pick a model (K) and one of its clusters, then shows a
    radar chart and a statistics table built from the cluster means returned
    by ``cm.get_cluster_info(k)``. Must be called inside an active
    ``gr.Blocks`` context, after the module-level ``cm`` has been populated.
    """
    k_choices = [3, 4]
    initial_k = 4

    def update_cluster_choices(k):
        """Rebuild the cluster dropdown so it offers clusters 0..k-1."""
        return gr.Dropdown(
            choices=list(range(k)),
            value=0,
            interactive=True,
        )

    def update_radar_and_stats(k, cluster_idx):
        """Return ``(radar_fig, stats_df)`` for cluster ``cluster_idx`` of model ``k``."""
        # k_select.change runs this listener separately from the dropdown
        # refresh, so cluster_idx can still be the selection made under the
        # previous K (e.g. 3 after switching from K=4 to K=3) or be unset.
        # Fall back to the first cluster instead of passing a bad index on.
        if cluster_idx is None or not 0 <= cluster_idx < k:
            cluster_idx = 0

        cluster_means = cm.get_cluster_info(k)["means"]
        radar_fig = plot_radar_chart(cluster_means, k, cluster_idx=cluster_idx)
        stats_df = create_cluster_stats_table(cluster_means, k)
        return radar_fig, stats_df

    with gr.TabItem("Analysis - Customer DNA"):
        gr.Markdown("# Deep Analysis: Characteristics of Each Cluster")

        gr.Markdown("""
        Select a cluster to see detailed characteristics of customers in it.
        The Radar chart shows how this cluster differs from the overall average.
        """)

        with gr.Row():
            k_select = gr.Radio(
                choices=k_choices,
                value=initial_k,
                label="Select Model (K clusters)"
            )

            cluster_select = gr.Dropdown(
                # Derived from initial_k so the two defaults cannot drift apart.
                choices=list(range(initial_k)),
                value=0,
                label="Select Cluster",
                interactive=True
            )

        k_select.change(
            fn=update_cluster_choices,
            inputs=k_select,
            outputs=cluster_select
        )

        gr.Markdown("### Radar Chart - Comparison with Overall Average")

        # Hand the initial figure/table to the constructors: assigning
        # ``.value`` after construction would skip whatever processing the
        # components apply to their initial value.
        initial_radar, initial_stats = update_radar_and_stats(initial_k, 0)
        radar_plot = gr.Plot(label="Radar Chart", value=initial_radar)
        stats_table = gr.Dataframe(label="Detailed Statistics", value=initial_stats)

        k_select.change(
            fn=update_radar_and_stats,
            inputs=[k_select, cluster_select],
            outputs=[radar_plot, stats_table]
        )
        cluster_select.change(
            fn=update_radar_and_stats,
            inputs=[k_select, cluster_select],
            outputs=[radar_plot, stats_table]
        )

        gr.Markdown("""
        ### How to Read Radar Chart:
        - **Each axis = 1 customer characteristic** (normalized 0-1 scale)
        - **Further from center = higher value** for that characteristic
        - **Shape of polygon** represents the cluster's profile
        - **Compare clusters** by looking at shape and size
        """)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Initialise shared state and assemble the Gradio interface.

    Rebinds the module-level state slots to the values returned by
    ``initialize_app()``, then builds a ``gr.Blocks`` layout made of a
    header, the three tabs and a footer.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` object.
    """
    global data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache

    print("Starting Gradio app initialization...")
    (
        data_loader,
        cm,
        raw_data,
        scaled_features,
        original_features,
        pca_plots_cache,
    ) = initialize_app()
    print("App initialized successfully!")

    # Tabs are rendered in the order listed here.
    tab_builders = (create_tab1, create_tab2, create_tab3)

    with gr.Blocks(title="Customer Segmentation Demo") as demo:
        gr.Markdown("""
        # Customer Segmentation - Advanced Analysis

        Interactive demo showcasing customer clustering analysis with K-Means.
        Explore data stories, clustering patterns, and predict segments for new customers.
        """)

        for build_tab in tab_builders:
            build_tab()

        gr.Markdown("""
        ---
        **Project:** Advanced Customer Segmentation
        **Data:** Online Retail (2010-2011) - Customers: 3,920+ - Transactions: 354,000+
        **Built from:** Project by Dr.Nguyen Thai Ha
        """)

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it on port 7860.
    # "0.0.0.0" makes the server listen on all network interfaces.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = main()
    demo.launch(**launch_options)
|
|
|